pickel-cancer-rick/project-cancer-classificati...
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Laden der Rohdaten"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Laden der 'kirp' Liste aus der Pickle-Datei\n",
"with open('rick.pickle', 'rb') as f:\n",
" data_frame = pickle.load(f)"
]
},
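{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the loaded data. This is a minimal sketch that assumes `data_frame` is a Pandas DataFrame with the columns `genome_frequencies` and `cancer_type`, which the cells below rely on; adjust if the pickle holds a different structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the loaded object (assumes a DataFrame with the columns used below)\n",
"print(type(data_frame))\n",
"print(data_frame.shape)\n",
"print(data_frame.columns.tolist())\n",
"print(data_frame['cancer_type'].value_counts())"
]
},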
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Aktiviere Cuda Support"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA is available on your system.\n"
]
}
],
"source": [
"import torch\n",
"device = \"cpu\"\n",
"if torch.cuda.is_available():\n",
" print(\"CUDA is available on your system.\")\n",
" device = \"cuda\"\n",
"else:\n",
" print(\"CUDA is not available on your system.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PCA Klasse zu Reduktion der Dimensionen"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"import torch\n",
"import pandas as pd\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split\n",
"from typing import List, Tuple\n",
"\n",
"\n",
"class GenomeDataset(Dataset):\n",
" \"\"\"\n",
" Eine benutzerdefinierte Dataset-Klasse, die für die Handhabung von Genomdaten konzipiert ist.\n",
" Diese Klasse wendet eine Principal Component Analysis (PCA) auf die Frequenzen der Genome an\n",
" und teilt den Datensatz in Trainings- und Validierungsteile auf.\n",
"\n",
" Attributes:\n",
" dataframe (pd.DataFrame): Ein Pandas DataFrame, der die initialen Daten enthält.\n",
" train_df (pd.DataFrame): Ein DataFrame, der den Trainingsdatensatz nach der Anwendung von PCA und der Aufteilung enthält.\n",
" val_df (pd.DataFrame): Ein DataFrame, der den Validierungsdatensatz nach der Anwendung von PCA und der Aufteilung enthält.\n",
"\n",
" Methods:\n",
" __init__(self, dataframe, n_pca_components=1034, train_size=0.8, split_random_state=42):\n",
" Konstruktor für die GenomeDataset Klasse.\n",
" _do_PCA(self, frequencies, n_components=1034):\n",
" Wendet PCA auf die gegebenen Frequenzen an.\n",
" _split_dataset(self, train_size=0.8, random_state=42):\n",
" Teilt den DataFrame in Trainings- und Validierungsdatensätze auf.\n",
" __getitem__(self, index):\n",
" Gibt ein Tupel aus transformierten Frequenzen und dem zugehörigen Krebstyp für einen gegebenen Index zurück.\n",
" __len__(self):\n",
" Gibt die Gesamtlänge der kombinierten Trainings- und Validierungsdatensätze zurück.\n",
" \"\"\"\n",
"\n",
" def __init__(self, dataframe: pd.DataFrame, n_pca_components: int = 1034, train_size: float = 0.8, split_random_state: int = 42):\n",
" \"\"\"\n",
" Konstruktor für die GenomeDataset Klasse.\n",
"\n",
" Parameters:\n",
" dataframe (pd.DataFrame): Der DataFrame, der die Genome Frequenzen und Krebsarten enthält.\n",
" n_pca_components (int): Die Anzahl der PCA-Komponenten, auf die reduziert werden soll. Standardwert ist 1034.\n",
" train_size (float): Der Anteil der Daten, der als Trainingsdaten verwendet werden soll. Standardwert ist 0.8.\n",
" split_random_state (int): Der Zufalls-Saatwert, der für die Aufteilung des Datensatzes verwendet wird. Standardwert ist 42.\n",
" \"\"\"\n",
" self.dataframe = dataframe\n",
"\n",
" # Umwandlung der Krebsarten in numerische Werte\n",
" self.label_encoder = LabelEncoder()\n",
" self.dataframe['encoded_cancer_type'] = self.label_encoder.fit_transform(dataframe['cancer_type'])\n",
"\n",
" # Anwenden der PCA auf die Frequenzen\n",
" self.dataframe['pca_frequencies'] = self._do_PCA(self.dataframe['genome_frequencies'].tolist(), n_pca_components)\n",
"\n",
" # Teilen des DataFrame in Trainings- und Validierungsdatensatz\n",
" self._split_dataset(train_size=train_size, random_state=split_random_state)\n",
"\n",
" def transform_datapoint(self, datapoint: List[float]) -> List[float]:\n",
" \"\"\"\n",
" Transformiert einen einzelnen Datenpunkt durch Standardisierung und Anwendung der PCA.\n",
"\n",
" Diese Methode nimmt einen rohen Datenpunkt (eine Liste von Frequenzen), standardisiert ihn mit dem \n",
" zuvor angepassten Scaler und wendet dann die PCA-Transformation an, um ihn in den reduzierten \n",
" Feature-Raum zu überführen, der für das Training des Modells verwendet wurde.\n",
"\n",
" Parameters:\n",
" datapoint (List[float]): Ein roher Datenpunkt, bestehend aus einer Liste von Frequenzen.\n",
"\n",
" Returns:\n",
" List[float]: Der transformierte Datenpunkt, nach Anwendung der Standardisierung und der PCA.\n",
" \"\"\"\n",
" # Standardisierung des Datenpunkts\n",
" scaled_data_point = self.scaler.transform([datapoint])\n",
"\n",
" # PCA-Transformation des standardisierten Datenpunkts\n",
" pca_transformed_point = self.pca.transform(scaled_data_point)\n",
"\n",
" return pca_transformed_point.tolist()\n",
"\n",
" def _do_PCA(self, frequencies: List[List[float]], n_components: int = 1034) -> List[List[float]]:\n",
" \"\"\"\n",
" Wendet PCA auf die gegebenen Frequenzen an.\n",
"\n",
" Parameters:\n",
" frequencies (List[List[float]]): Die Liste der Frequenzen, auf die die PCA angewendet werden soll.\n",
" n_components (int): Die Anzahl der Komponenten für die PCA. Standardwert ist 1034.\n",
"\n",
" Returns:\n",
" List[List[float]]: Eine Liste von Listen, die die transformierten Frequenzen nach der PCA darstellt.\n",
" \"\"\"\n",
"\n",
" # Standardisieren der Frequenzen\n",
" self.scaler = StandardScaler()\n",
" scaled_frequencies = self.scaler.fit_transform(frequencies)\n",
"\n",
" # PCA-Instanz erstellen und auf die gewünschte Anzahl von Komponenten reduzieren\n",
" self.pca = PCA(n_components=n_components)\n",
"\n",
" # PCA auf die Frequenzen anwenden\n",
" pca_result = self.pca.fit_transform(scaled_frequencies)\n",
"\n",
" return pca_result.tolist()\n",
"\n",
" def _split_dataset(self, train_size: float = 0.8, random_state: int = 42):\n",
" \"\"\"\n",
" Teilt den DataFrame in Trainings- und Validierungsdatensätze auf.\n",
"\n",
" Parameters:\n",
" train_size (float): Der Anteil der Daten, der als Trainingsdaten verwendet werden soll.\n",
" random_state (int): Der Zufalls-Saatwert, der für die Aufteilung des Datensatzes verwendet wird.\n",
" \"\"\"\n",
"\n",
" class SplittedDataset(Dataset):\n",
" def __init__(self, dataframe):\n",
" self.dataframe = dataframe\n",
"\n",
" # Umwandlung der Genome Frequenzen in Tensoren\n",
" self.genome_frequencies = torch.tensor(dataframe['pca_frequencies'].tolist(), dtype=torch.float32)\n",
"\n",
" # Umwandlung der Krebsarten in numerische Werte\n",
" self.label_encoder = LabelEncoder()\n",
" self.cancer_types = torch.tensor(dataframe['encoded_cancer_type'].tolist(), dtype=torch.long)\n",
"\n",
" def __getitem__(self, index):\n",
" # Rückgabe eines Tupels aus Genome Frequenzen und dem entsprechenden Krebstyp\n",
" return self.genome_frequencies[index], self.cancer_types[index]\n",
"\n",
" def __len__(self):\n",
" return len(self.dataframe)\n",
"\n",
" # Teilen des DataFrame in Trainings- und Validierungsdatensatz\n",
" train_df, val_df = train_test_split(self.dataframe, train_size=train_size, random_state=random_state)\n",
" self.train_df = SplittedDataset(train_df)\n",
" self.val_df = SplittedDataset(val_df)\n",
"\n",
"\n",
" def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:\n",
" \"\"\"\n",
" Gibt ein Tupel aus transformierten Frequenzen und dem entsprechenden Krebstyp für einen gegebenen Index zurück.\n",
"\n",
" Parameters:\n",
" index (int): Der Index des zu abrufenden Datenelements.\n",
"\n",
" Returns:\n",
" Tuple[torch.Tensor, int]: Ein Tupel, bestehend aus einem Tensor der transformierten Frequenzen und dem zugehörigen Krebstyp.\n",
" \"\"\"\n",
"\n",
" print(self.train_df.shape)\n",
" print(self.val_df.shape)\n",
" \n",
" if index < len(self.train_df):\n",
" row = self.train_df.iloc[index]\n",
" else:\n",
" row = self.val_df.iloc[len(self.train_df) - index]\n",
"\n",
" pca_frequencies_tensor = torch.tensor(row['pca_frequencies'], dtype=torch.float32)\n",
" cancer_type = row['encoded_cancer_type']\n",
"\n",
" return pca_frequencies_tensor, cancer_type\n",
"\n",
" def __len__(self) -> int:\n",
" \"\"\"\n",
" Gibt die Gesamtlänge der kombinierten Trainings- und Validierungsdatensätze zurück.\n",
"\n",
" Returns:\n",
" int: Die Länge der kombinierten Datensätze.\n",
" \"\"\"\n",
" \n",
" return len(self.train_df) + len(self.val_df)\n"
]
},
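{
"cell_type": "markdown",
"metadata": {},
"source": [
"Usage sketch for `transform_datapoint`: once a `GenomeDataset` has been constructed (which fits the scaler and the PCA), a single raw frequency vector can be projected into the same reduced feature space, e.g. for inference on new samples. The small component count below is an arbitrary choice for this demo and assumes the data has at least that many samples and features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: fit a small GenomeDataset, then transform one raw datapoint\n",
"demo_dataset = GenomeDataset(data_frame, n_pca_components=16)\n",
"raw_point = data_frame['genome_frequencies'].iloc[0]\n",
"reduced_point = demo_dataset.transform_datapoint(raw_point)\n",
"print(len(reduced_point[0])) # -> 16 PCA components"
]
},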
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Definition des neuronalen Netzes"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"\n",
"class CancerClassifierNN(nn.Module):\n",
" \"\"\"\n",
" Eine benutzerdefinierte neuronale Netzwerkklassifikator-Klasse für die Krebsklassifikation.\n",
"\n",
" Diese Klasse definiert ein mehrschichtiges Perzeptron (MLP), das für die Klassifizierung von Krebsarten\n",
" anhand genetischer Frequenzdaten verwendet wird.\n",
"\n",
" Attributes:\n",
" fc1 (nn.Linear): Die erste lineare Schicht des Netzwerks.\n",
" fc2 (nn.Linear): Die zweite lineare Schicht des Netzwerks.\n",
" fc3 (nn.Linear): Die dritte lineare Schicht des Netzwerks.\n",
" fc4 (nn.Linear): Die Ausgabeschicht des Netzwerks.\n",
" dropout (nn.Dropout): Ein Dropout-Layer zur Vermeidung von Overfitting.\n",
"\n",
" Methods:\n",
" __init__(self, input_size: int, num_classes: int):\n",
" Konstruktor für die CancerClassifierNN Klasse.\n",
" forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" Definiert den Vorwärtsdurchlauf des Netzwerks.\n",
" \"\"\"\n",
"\n",
" def __init__(self, input_size: int, num_classes: int):\n",
" \"\"\"\n",
" Konstruktor für die CancerClassifierNN Klasse.\n",
"\n",
" Parameters:\n",
" input_size (int): Die Größe des Input-Features.\n",
" num_classes (int): Die Anzahl der Zielklassen.\n",
" \"\"\"\n",
" super(CancerClassifierNN, self).__init__()\n",
" # Definieren der Schichten\n",
2024-01-05 13:19:38 +01:00
" self.fc1 = nn.Linear(input_size, input_size) # Eingabeschicht\n",
" self.fc2 = nn.Linear(input_size, input_size//2) # Versteckte Schicht\n",
" self.fc3 = nn.Linear(input_size//2, input_size//4) # Weitere versteckte Schicht\n",
" self.fc4 = nn.Linear(input_size//4, num_classes) # Ausgabeschicht\n",
2024-01-04 14:47:29 +01:00
" self.dropout = nn.Dropout(p=0.5) # Dropout\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" Definiert den Vorwärtsdurchlauf des Netzwerks.\n",
"\n",
" Parameters:\n",
" x (torch.Tensor): Der Input-Tensor für das Netzwerk.\n",
"\n",
" Returns:\n",
" torch.Tensor: Der Output-Tensor nach dem Durchlauf durch das Netzwerk.\n",
" \"\"\"\n",
" x = F.relu(self.fc1(x))\n",
" x = self.dropout(x)\n",
" x = F.relu(self.fc2(x))\n",
" x = self.dropout(x)\n",
" x = F.relu(self.fc3(x))\n",
2024-01-05 13:19:38 +01:00
" x = self.dropout(x)\n",
" x = torch.softmax(self.fc4(x), dim=1) # Oder F.log_softmax(x, dim=1) für Mehrklassenklassifikation\n",
2024-01-04 14:47:29 +01:00
" return x"
]
},
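{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check on the untrained network: a random batch pushed through the model should yield one logit per class. The sizes below mirror the ones used in the next cell, but any consistent pair would do."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: forward a random batch through an untrained instance\n",
"demo_model = CancerClassifierNN(input_size=96, num_classes=3)\n",
"demo_batch = torch.randn(4, 96) # 4 synthetic samples\n",
"print(demo_model(demo_batch).shape) # -> torch.Size([4, 3])"
]
},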
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"N_COMPONENTS = 96\n",
"\n",
"# Create the dataset object\n",
"genome_dataset = GenomeDataset(data_frame, n_pca_components=N_COMPONENTS)\n",
"train_dataset = genome_dataset.train_df\n",
"valid_dataset = genome_dataset.val_df\n",
"\n",
"# Annahme: input_size ist die Länge Ihrer Genome-Frequenzen und num_classes ist die Anzahl der Krebsarten\n",
2024-01-04 15:15:24 +01:00
"model = CancerClassifierNN(input_size=N_COMPONENTS, num_classes=3)\n",
2024-01-04 15:06:10 +01:00
"model.to(device=device)\n",
2024-01-04 14:47:29 +01:00
"\n",
"# Daten-Loader\n",
"train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)\n",
"valid_loader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=False)"
]
},
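{
"cell_type": "markdown",
"metadata": {},
"source": [
"Peeking at one batch from the training loader confirms the tensor shapes the network expects: `(batch_size, N_COMPONENTS)` float inputs and `(batch_size,)` integer labels. A quick sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw a single batch to verify shapes and dtypes\n",
"inputs, labels = next(iter(train_loader))\n",
"print(inputs.shape, inputs.dtype) # e.g. torch.Size([64, 96]) torch.float32\n",
"print(labels.shape, labels.dtype) # e.g. torch.Size([64]) torch.int64"
]
},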
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"# Verlustfunktion\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"learning_rate = 0.0003\n",
"\n",
"# Optimizer\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"# Number of epochs\n",
"num_epochs = 1000\n"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1000/1000], Training loss: 0.5610, Training accuracy: 0.9915, Validation loss: 0.5992, Validation accuracy: 0.9565\n"
]
}
],
"source": [
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Lists for storing the losses and accuracies\n",
"train_losses = []\n",
"valid_losses = []\n",
"train_accuracies = []\n",
"valid_accuracies = []\n",
"\n",
"for epoch in range(num_epochs):\n",
" model.train()\n",
" train_loss = 0.0\n",
" correct_predictions = 0\n",
" total_predictions = 0\n",
"\n",
" for i, (inputs, labels) in enumerate(train_loader):\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
"\n",
" # Compute the accuracy\n",
" _, predicted = torch.max(outputs, 1)\n",
" correct_predictions += (predicted == labels).sum().item()\n",
" total_predictions += labels.size(0)\n",
"\n",
" # Average training loss and accuracy\n",
" train_loss /= len(train_loader)\n",
" train_accuracy = correct_predictions / total_predictions\n",
" train_losses.append(train_loss)\n",
" train_accuracies.append(train_accuracy)\n",
"\n",
" # Validation loss and accuracy\n",
" model.eval()\n",
" valid_loss = 0.0\n",
" correct_predictions = 0\n",
" total_predictions = 0\n",
"\n",
" with torch.no_grad():\n",
" for inputs, labels in valid_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" valid_loss += loss.item()\n",
"\n",
" # Compute the accuracy\n",
" _, predicted = torch.max(outputs, 1)\n",
" correct_predictions += (predicted == labels).sum().item()\n",
" total_predictions += labels.size(0)\n",
"\n",
" # Average validation loss and accuracy\n",
" valid_loss /= len(valid_loader)\n",
" valid_accuracy = correct_predictions / total_predictions\n",
" valid_losses.append(valid_loss)\n",
" valid_accuracies.append(valid_accuracy)\n",
"\n",
" # Early stopping once the validation accuracy saturates\n",
" if valid_accuracy >= 0.999:\n",
" break\n",
"\n",
" # Update the plot\n",
" clear_output(wait=True)\n",
" fig, ax1 = plt.subplots()\n",
"\n",
" # Draw the loss curves\n",
" ax1.plot(train_losses, label='Training loss', color='r')\n",
" ax1.plot(valid_losses, label='Validation loss', color='b')\n",
" ax1.set_xlabel('Epochs')\n",
" ax1.set_ylabel('Loss', color='g')\n",
" ax1.tick_params(axis='y', labelcolor='g')\n",
"\n",
" # Second y-axis for the accuracy\n",
" ax2 = ax1.twinx()\n",
" ax2.plot(train_accuracies, label='Training accuracy', color='r', linestyle='dashed')\n",
" ax2.plot(valid_accuracies, label='Validation accuracy', color='b', linestyle='dashed')\n",
" ax2.set_ylabel('Accuracy', color='g')\n",
" ax2.tick_params(axis='y', labelcolor='g')\n",
"\n",
" # Title and legend\n",
" plt.title(f'Training and validation loss and accuracy over time with \\n{N_COMPONENTS} principal components, learning rate: {learning_rate}')\n",
" fig.tight_layout()\n",
"\n",
" # Place the legends outside the plot\n",
" ax1.legend(loc='upper left', bbox_to_anchor=(1.15, 1))\n",
" ax2.legend(loc='upper left', bbox_to_anchor=(1.15, 0.85))\n",
"\n",
" plt.show()\n",
"\n",
" print(f'Epoch [{epoch+1}/{num_epochs}], Training loss: {train_loss:.4f}, Training accuracy: {train_accuracies[-1]:.4f}, Validation loss: {valid_loss:.4f}, Validation accuracy: {valid_accuracies[-1]:.4f}')"
]
},
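{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, the weights can be persisted and reloaded for later inference. A minimal sketch; the file name `cancer_classifier.pt` is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained weights (file name is an arbitrary choice)\n",
"torch.save(model.state_dict(), 'cancer_classifier.pt')\n",
"\n",
"# Reload into a fresh instance for inference\n",
"restored = CancerClassifierNN(input_size=N_COMPONENTS, num_classes=3)\n",
"restored.load_state_dict(torch.load('cancer_classifier.pt', map_location=device))\n",
"restored.to(device)\n",
"restored.eval()"
]
},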
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "rl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}