From f1cad86475f9b5844d63bb015eb5378951531834 Mon Sep 17 00:00:00 2001 From: Meiko Remiorz Date: Mon, 8 Jan 2024 14:20:49 +0100 Subject: [PATCH] pushibus --- project-cancer-classification-pca-nn.ipynb | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/project-cancer-classification-pca-nn.ipynb b/project-cancer-classification-pca-nn.ipynb index 52f408a..454e8d2 100644 --- a/project-cancer-classification-pca-nn.ipynb +++ b/project-cancer-classification-pca-nn.ipynb @@ -242,6 +242,7 @@ "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", + "from sklearn.metrics import confusion_matrix\n", "\n", "class CancerClassifierNN(nn.Module):\n", " \"\"\"\n", @@ -297,7 +298,31 @@ " x = F.relu(self.fc3(x))\n", " x = self.dropout(x)\n", " x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n", - " return x" + " return x\n", + " \n", + " def calculate_confusion_matrix(self, model, dataset_loader):\n", + " \"\"\"\n", + " Berechnet die Konfusionsmatrix für das gegebene Modell und den Datensatz.\n", + "\n", + " Parameters:\n", + " model (torch.nn.Module): Das PyTorch-Modell.\n", + " dataset_loader (torch.utils.data.DataLoader): Der DataLoader für den Datensatz.\n", + "\n", + " Returns:\n", + " np.array: Die Konfusionsmatrix.\n", + " \"\"\"\n", + " model.eval()\n", + " all_preds = []\n", + " all_targets = []\n", + "\n", + " with torch.no_grad():\n", + " for data, target in dataset_loader:\n", + " outputs = model(data)\n", + " _, preds = torch.max(outputs, 1)\n", + " all_preds.extend(preds.cpu().numpy())\n", + " all_targets.extend(target.cpu().numpy())\n", + "\n", + " return confusion_matrix(all_targets, all_preds)" ] }, {