pushibus
This commit is contained in:
parent
daefea4207
commit
f1cad86475
|
@ -242,6 +242,7 @@
|
||||||
"import torch.nn as nn\n",
|
"import torch.nn as nn\n",
|
||||||
"import torch.optim as optim\n",
|
"import torch.optim as optim\n",
|
||||||
"import torch.nn.functional as F\n",
|
"import torch.nn.functional as F\n",
|
||||||
|
"from sklearn.metrics import confusion_matrix\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class CancerClassifierNN(nn.Module):\n",
|
"class CancerClassifierNN(nn.Module):\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
|
@ -297,7 +298,31 @@
|
||||||
" x = F.relu(self.fc3(x))\n",
|
" x = F.relu(self.fc3(x))\n",
|
||||||
" x = self.dropout(x)\n",
|
" x = self.dropout(x)\n",
|
||||||
" x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n",
|
" x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n",
|
||||||
" return x"
|
" return x\n",
|
||||||
|
" \n",
|
||||||
|
" def calculate_confusion_matrix(self, model, dataset_loader):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Berechnet die Konfusionsmatrix für das gegebene Modell und den Datensatz.\n",
|
||||||
|
"\n",
|
||||||
|
" Parameters:\n",
|
||||||
|
" model (torch.nn.Module): Das PyTorch-Modell.\n",
|
||||||
|
" dataset_loader (torch.utils.data.DataLoader): Der DataLoader für den Datensatz.\n",
|
||||||
|
"\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" np.array: Die Konfusionsmatrix.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" all_preds = []\n",
|
||||||
|
" all_targets = []\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for data, target in dataset_loader:\n",
|
||||||
|
" outputs = model(data)\n",
|
||||||
|
" _, preds = torch.max(outputs, 1)\n",
|
||||||
|
" all_preds.extend(preds.cpu().numpy())\n",
|
||||||
|
" all_targets.extend(target.cpu().numpy())\n",
|
||||||
|
"\n",
|
||||||
|
" return confusion_matrix(all_targets, all_preds)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue