pushibus
This commit is contained in:
parent
daefea4207
commit
f1cad86475
|
@ -242,6 +242,7 @@
|
|||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from sklearn.metrics import confusion_matrix\n",
|
||||
"\n",
|
||||
"class CancerClassifierNN(nn.Module):\n",
|
||||
" \"\"\"\n",
|
||||
|
@ -297,7 +298,31 @@
|
|||
" x = F.relu(self.fc3(x))\n",
|
||||
" x = self.dropout(x)\n",
|
||||
" x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n",
|
||||
" return x"
|
||||
" return x\n",
|
||||
" \n",
|
||||
" def calculate_confusion_matrix(self, model, dataset_loader):\n",
|
||||
" \"\"\"\n",
|
||||
" Berechnet die Konfusionsmatrix für das gegebene Modell und den Datensatz.\n",
|
||||
"\n",
|
||||
" Parameters:\n",
|
||||
" model (torch.nn.Module): Das PyTorch-Modell.\n",
|
||||
" dataset_loader (torch.utils.data.DataLoader): Der DataLoader für den Datensatz.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" np.array: Die Konfusionsmatrix.\n",
|
||||
" \"\"\"\n",
|
||||
" model.eval()\n",
|
||||
" all_preds = []\n",
|
||||
" all_targets = []\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for data, target in dataset_loader:\n",
|
||||
" outputs = model(data)\n",
|
||||
" _, preds = torch.max(outputs, 1)\n",
|
||||
" all_preds.extend(preds.cpu().numpy())\n",
|
||||
" all_targets.extend(target.cpu().numpy())\n",
|
||||
"\n",
|
||||
" return confusion_matrix(all_targets, all_preds)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue