This commit is contained in:
Meiko Remiorz 2024-01-08 14:20:49 +01:00
parent daefea4207
commit f1cad86475
1 changed files with 26 additions and 1 deletions

View File

@ -242,6 +242,7 @@
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"class CancerClassifierNN(nn.Module):\n",
" \"\"\"\n",
@ -297,7 +298,31 @@
" x = F.relu(self.fc3(x))\n",
" x = self.dropout(x)\n",
" x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n",
" return x"
" return x\n",
" \n",
" def calculate_confusion_matrix(self, model, dataset_loader):\n",
" \"\"\"\n",
" Berechnet die Konfusionsmatrix für das gegebene Modell und den Datensatz.\n",
"\n",
" Parameters:\n",
" model (torch.nn.Module): Das PyTorch-Modell.\n",
" dataset_loader (torch.utils.data.DataLoader): Der DataLoader für den Datensatz.\n",
"\n",
" Returns:\n",
" np.array: Die Konfusionsmatrix.\n",
" \"\"\"\n",
" model.eval()\n",
" all_preds = []\n",
" all_targets = []\n",
"\n",
" with torch.no_grad():\n",
" for data, target in dataset_loader:\n",
" outputs = model(data)\n",
" _, preds = torch.max(outputs, 1)\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_targets.extend(target.cpu().numpy())\n",
"\n",
" return confusion_matrix(all_targets, all_preds)"
]
},
{