{
"cells": [
{
"cell_type": "markdown",
"id": "8bc02404-8cd1-46d9-8237-2d035ebb3e79",
"metadata": {},
"source": [
"# **[Project] Cancer Subtype Classification**"
]
},
{
"cell_type": "markdown",
"id": "0c5076f4",
"metadata": {},
"source": [
"# Introduction"
]
},
{
"cell_type": "markdown",
"id": "8a599748",
"metadata": {},
"source": [
"The [TCGA Kidney Cancers Dataset](https://archive.ics.uci.edu/dataset/892/tcga+kidney+cancers) is a bulk RNA-seq dataset that contains transcriptome profiles (i.e., gene expression quantification data) of patients diagnosed with one of three subtypes of kidney cancer.\n",
"This dataset can be used to predict the kidney cancer subtype of a sample from its normalized transcriptome profile.\n",
"\n",
"The normalized transcriptome profile data is given as **TPM** and **FPKM** for each gene.\n",
"\n",
"> TPM (Transcripts Per Million) and FPKM (Fragments Per Kilobase of transcript per Million mapped reads) are two common methods for quantifying gene expression in RNA sequencing data.\n",
"> They both aim to account for the differences in sequencing depth and transcript length when estimating gene expression levels.\n",
">\n",
"> **TPM** (Transcripts Per Million):\n",
"> - TPM is a measure of gene expression that normalizes for both library size (sequencing depth) and transcript length.\n",
"> - The main idea behind TPM is to express the abundance of a transcript relative to the total number of transcripts in a sample, scaled to one million.\n",
">\n",
"> **FPKM** (Fragments Per Kilobase of transcript per Million mapped reads):\n",
"> - FPKM is another method for quantifying gene expression, which is commonly used in older RNA-seq analysis pipelines. It is similar in concept to TPM, but it normalizes by the total number of mapped fragments instead of by the sum of the length-normalized counts.\n",
"> - FPKM also normalizes for library size and transcript length: it measures gene expression as the number of fragments (reads for single-end data, read pairs for paired-end data) per kilobase of exon model per million mapped fragments.\n",
">\n",
"> Because the TPM values of every sample sum to the same total (one million), TPM values can be compared directly as proportions across samples. FPKM values lack this property, which is one reason TPM is preferred in many modern RNA-seq analysis workflows.\n",
"\n",
"We provide one dataset for each kidney cancer subtype:\n",
"\n",
"- [TCGA-KICH](https://portal.gdc.cancer.gov/projects/TCGA-KICH): kidney chromophobe (chromophobe renal cell carcinoma)\n",
"- [TCGA-KIRC](https://portal.gdc.cancer.gov/projects/TCGA-KIRC): kidney renal clear cell carcinoma\n",
"- [TCGA-KIRP](https://portal.gdc.cancer.gov/projects/TCGA-KIRP): kidney renal papillary cell carcinoma\n",
"\n",
"> This and _much_ more data is openly available on the [NCI Genomic Data Commons (GDC) Data Portal](https://portal.gdc.cancer.gov/)."
]
},
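{
"cell_type": "markdown",
"id": "tpm-fpkm-sketch-md",
"metadata": {},
"source": [
"The two normalizations can be made concrete with a few lines of NumPy. The sketch below illustrates them on made-up counts and gene lengths (`counts`, `lengths_bp`, `fpkm` and `tpm` are hypothetical names). It is not the GDC pipeline, which computes these values upstream. FPKM scales each count by the total number of mapped fragments and by gene length. TPM first divides by gene length and then rescales so that every sample sums to one million, so TPM equals FPKM rescaled to a per-sample total of one million."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tpm-fpkm-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up raw fragment counts for four genes, treated here as the whole sample,\n",
"# and the corresponding gene (exon model) lengths in base pairs.\n",
"counts = np.array([500.0, 1000.0, 250.0, 0.0])\n",
"lengths_bp = np.array([2000.0, 4000.0, 500.0, 1000.0])\n",
"\n",
"def fpkm(counts, lengths_bp):\n",
"    # Fragments per million mapped fragments, then per kilobase of gene length.\n",
"    per_million = counts / (counts.sum() / 1e6)\n",
"    return per_million / (lengths_bp / 1e3)\n",
"\n",
"def tpm(counts, lengths_bp):\n",
"    # Length-normalize first (per kilobase), then rescale the sample to sum to 1e6.\n",
"    rpk = counts / (lengths_bp / 1e3)\n",
"    return rpk / rpk.sum() * 1e6\n",
"\n",
"print('FPKM:', fpkm(counts, lengths_bp))\n",
"print('TPM: ', tpm(counts, lengths_bp))\n",
"print('TPM sum:', tpm(counts, lengths_bp).sum())"
]
},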
{
"cell_type": "markdown",
"id": "16712787",
"metadata": {},
"source": [
"# Data access"
]
},
{
"cell_type": "markdown",
"id": "6421ef6c",
"metadata": {},
"source": [
"There are two ways to access the data: via the TNT homepage or the GDC Data Portal."
]
},
{
"cell_type": "markdown",
"id": "b977e8b8",
"metadata": {},
"source": [
"## Download from the TNT homepage (_recommended_)"
]
},
{
"cell_type": "markdown",
"id": "800fa7bd",
"metadata": {},
"source": [
"The download from the TNT homepage is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "dda97b16",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# ! wget http://www.tnt.uni-hannover.de/edu/vorlesungen/AMLG/data/project-cancer-classification.tar.gz\n",
"# ! tar -xzvf project-cancer-classification.tar.gz\n",
"# ! mv -v project-cancer-classification/ data/\n",
"# ! rm -v project-cancer-classification.tar.gz"
]
},
{
"cell_type": "markdown",
"id": "bc2db880",
"metadata": {},
"source": [
"In the `data/` folder you will now find many files in the [TSV format](https://en.wikipedia.org/wiki/Tab-separated_values) ([CSV](https://en.wikipedia.org/wiki/Comma-separated_values)-like, with tabs as the delimiter) containing the normalized transcriptome profile data.\n",
"\n",
"To start, you can read a TSV file into a [pandas](https://pandas.pydata.org) [`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html) using the [`pandas.read_csv()`](https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html#pandas-read-csv) function with the `sep` parameter set to `\\t`:"
]
},
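{
"cell_type": "markdown",
"id": "read-tsv-sketch-md",
"metadata": {},
"source": [
"As a sketch of that call, the cell below parses a tiny in-memory stand-in for one of the expression files. Every row and value in it is made up, and only a subset of the real columns is included. It assumes the layout that the loading code below relies on: a leading comment line (hence `header=1`) and four STAR summary rows (`N_unmapped`, `N_multimapping`, `N_noFeature`, `N_ambiguous`) before the per-gene rows. This layout is why the first four entries of `tpm_unstranded` are dropped."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "read-tsv-sketch",
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Made-up miniature stand-in for one expression file: a comment line, the header row,\n",
"# four STAR summary rows and two gene rows (only a subset of the real columns).\n",
"rows = [\n",
"    '# hypothetical comment line',\n",
"    'gene_id\\tgene_name\\tunstranded\\ttpm_unstranded\\tfpkm_unstranded',\n",
"    'N_unmapped\\t\\t100\\t\\t',\n",
"    'N_multimapping\\t\\t50\\t\\t',\n",
"    'N_noFeature\\t\\t20\\t\\t',\n",
"    'N_ambiguous\\t\\t10\\t\\t',\n",
"    'GENE_A\\tA\\t500\\t12.5\\t3.1',\n",
"    'GENE_B\\tB\\t0\\t0.0\\t0.0',\n",
"]\n",
"tsv_file = io.StringIO('\\n'.join(rows) + '\\n')\n",
"\n",
"# header=1 skips the comment line and takes the column names from the second line.\n",
"df = pd.read_csv(tsv_file, sep='\\t', header=1)\n",
"\n",
"# The summary rows carry no TPM value (NaN), so drop them before using the column.\n",
"tpm_values = df['tpm_unstranded'].iloc[4:].to_numpy()\n",
"print(df)\n",
"print(tpm_values)"
]
},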
{
"cell_type": "markdown",
"id": "ed50d396-fe33-47a7-ad19-8eb975ef0fa5",
"metadata": {},
"source": [
"## Reading the RNA-seq expression files and saving them to a single file"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2adae4ff",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# import os\n",
"# import pickle\n",
"\n",
"# import numpy as np\n",
"# import pandas as pd\n",
"\n",
"# # One folder per cancer subtype, e.g. './data/tcga-kirp-geq'\n",
"# labels = [\"kirp\", \"kirc\", \"kich\"]\n",
"# n_files = 0\n",
"# rick = list()  # expression vectors only (used by the PCA cell)\n",
"# data = []  # [expression vector, label] pairs\n",
"\n",
"# for l in labels:\n",
"#     root_folder = f\"./data/tcga-{l}-geq\"\n",
"#     for root, dirs, files in os.walk(root_folder):\n",
"#         for file in files:\n",
"#             if file.endswith('.tsv'):\n",
"#                 n_files += 1\n",
"#                 # Full path to the file\n",
"#                 file_path = os.path.join(root, file)\n",
"#                 # Read the file; header=1 skips the leading comment line\n",
"#                 df = pd.read_csv(filepath_or_buffer=file_path, sep=\"\\t\", header=1)\n",
"#                 df = df['tpm_unstranded']\n",
"#                 # Drop the four STAR summary rows (N_unmapped, N_multimapping,\n",
"#                 # N_noFeature, N_ambiguous), which have no TPM value\n",
"#                 df = df[4:]\n",
"#                 df = np.array(df)\n",
"#                 rick.append(df)\n",
"#                 data.append([df, l])\n",
"\n",
"# print(f\"Read {n_files} files.\")\n",
"\n",
"# # Example file for inspecting a single DataFrame:\n",
"# # tsv_file_path = \"data/tcga-kich-geq/0ba21ef5-0829-422e-a674-d3817498c333/4868e8fc-e045-475a-a81d-ef43eabb7066.rna_seq.augmented_star_gene_counts.tsv\"\n",
"# # df = pd.read_csv(filepath_or_buffer=tsv_file_path, sep=\"\\t\", header=1)\n",
"# # print(df.head(n=20))\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dfe4f964-6068-46da-8103-194525086f01",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# data_Frame = pd.DataFrame(data, columns=[\"genome_frequencies\", \"cancer_type\"])\n",
"# data_Frame.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f5cc92a-4485-4184-845e-116ea9a9776d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# # Save the DataFrame to a local file\n",
"# with open('rick.pickle', 'wb') as f:\n",
"# pickle.dump(data_Frame, f)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b7b79958-baba-4630-9def-cf47afe43d9f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Load the DataFrame saved above from the pickle file\n",
"with open('rick.pickle', 'rb') as f:\n",
" data_Frame = pickle.load(f)"
]
},
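{
"cell_type": "markdown",
"id": "pandas-pickle-sketch-md",
"metadata": {},
"source": [
"pandas also provides `DataFrame.to_pickle()` and `pandas.read_pickle()`, which wrap the `open()` / `pickle.dump()` / `pickle.load()` steps used here into a single call each. Below is a minimal sketch with a two-row stand-in DataFrame (made-up values, same columns as `data_Frame`). As with `pickle.load()`, only load pickle files from trusted sources, since unpickling can execute arbitrary code."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pandas-pickle-sketch",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Tiny stand-in built the same way as data_Frame: rows of [expression vector, label].\n",
"toy = pd.DataFrame(\n",
"    [[np.array([1.0, 2.0]), 'kirp'], [np.array([0.5, 0.0]), 'kich']],\n",
"    columns=['genome_frequencies', 'cancer_type'],\n",
")\n",
"\n",
"path = os.path.join(tempfile.mkdtemp(), 'toy.pickle')\n",
"toy.to_pickle(path)              # same effect as pickle.dump(toy, f)\n",
"restored = pd.read_pickle(path)  # same effect as pickle.load(f)\n",
"print(restored)"
]
},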
{
"cell_type": "code",
"execution_count": 6,
"id": "f6608b92-8ace-4a52-a3dc-70c578e56f0d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>genome_frequencies</th>\n",
" <th>cancer_type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[20.331, 0.0, 25.1806, 1.1301, 0.4836, 7.3269,...</td>\n",
" <td>kirp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[37.0405, 0.5002, 77.4246, 4.2188, 1.0408, 29....</td>\n",
" <td>kirp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[45.4456, 0.0903, 74.9545, 4.843, 1.5188, 11.8...</td>\n",
" <td>kirp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[15.2345, 0.3393, 62.0003, 2.4412, 0.932, 2.66...</td>\n",
" <td>kirp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[35.0709, 0.2333, 62.8022, 2.8872, 1.0547, 18....</td>\n",
" <td>kirp</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" genome_frequencies cancer_type\n",
"0 [20.331, 0.0, 25.1806, 1.1301, 0.4836, 7.3269,... kirp\n",
"1 [37.0405, 0.5002, 77.4246, 4.2188, 1.0408, 29.... kirp\n",
"2 [45.4456, 0.0903, 74.9545, 4.843, 1.5188, 11.8... kirp\n",
"3 [15.2345, 0.3393, 62.0003, 2.4412, 0.932, 2.66... kirp\n",
"4 [35.0709, 0.2333, 62.8022, 2.8872, 1.0547, 18.... kirp"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_Frame.head()"
]
},
{
"cell_type": "markdown",
"id": "c60cbf60-d904-4ee0-8f70-588bb109368b",
"metadata": {},
"source": [
"# Data preprocessing"
]
},
{
"cell_type": "markdown",
"id": "583e39c8-13ba-422e-9c39-9cf1c8d63d5b",
"metadata": {},
"source": [
"## Training set & validation set"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "38695a70-86e9-4dd0-b622-33e3762372eb",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataSet shape: (1034, 2)\n",
"Training set\n",
"------------\n",
"Dataframe shape: (827, 2)\n",
"Dataframe head:\n",
" genome_frequencies cancer_type\n",
"518 [25.0645, 0.1125, 56.3997, 3.3108, 1.6061, 12.... kirc\n",
"355 [32.6449, 2.1789, 63.4954, 6.3228, 2.109, 40.9... kirc\n",
"528 [46.024, 0.0, 85.8077, 7.2567, 2.1301, 9.6509,... kirc\n",
"445 [153.0064, 1.6403, 99.3267, 7.3736, 1.3668, 10... kirc\n",
"986 [65.5167, 18.2363, 77.2126, 5.0375, 2.4628, 21... kich\n",
"\n",
"Validation set\n",
"--------------\n",
"Dataframe shape: (207, 2)\n",
"Dataframe head:\n",
" genome_frequencies cancer_type\n",
"294 [50.8994, 0.4635, 131.5049, 5.7193, 3.103, 15.... kirp\n",
"453 [35.857, 0.1018, 94.5681, 5.2997, 1.9388, 17.6... kirc\n",
"638 [11.3865, 0.2313, 28.5961, 3.0169, 0.7851, 8.2... kirc\n",
"139 [41.6119, 0.2207, 55.4377, 4.4395, 0.884, 3.56... kirp\n",
"539 [63.1646, 18.8107, 63.2703, 4.6696, 0.9466, 5.... kirc\n"
]
}
],
"source": [
"import os\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"train_df, val_df = train_test_split(data_Frame, train_size=0.8, random_state=42)\n",
"\n",
"print(f\"DataSet shape: {data_Frame.shape}\")\n",
"print(f\"Training set{os.linesep}------------\")\n",
"print(f\"Dataframe shape: {train_df.shape}\")\n",
"print(f\"Dataframe head:{os.linesep}{train_df.head()}\")\n",
"print(\"\")\n",
"print(f\"Validation set{os.linesep}--------------\")\n",
"print(f\"Dataframe shape: {val_df.shape}\")\n",
"print(f\"Dataframe head:{os.linesep}{val_df.head()}\")"
]
},
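{
"cell_type": "markdown",
"id": "stratified-split-sketch-md",
"metadata": {},
"source": [
"The split above is purely random. The three subtypes are not equally frequent, so passing `stratify=` to `train_test_split` is a variant worth considering: it keeps the class proportions as close to identical as possible in both splits. The sketch below shows the effect on a synthetic label column with made-up 60/30/10 proportions, not the real class sizes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stratified-split-sketch",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic, imbalanced label column (made-up 60/30/10 proportions).\n",
"toy_df = pd.DataFrame({'cancer_type': ['kirc'] * 60 + ['kirp'] * 30 + ['kich'] * 10})\n",
"\n",
"toy_train, toy_val = train_test_split(\n",
"    toy_df, train_size=0.8, random_state=42, stratify=toy_df['cancer_type']\n",
")\n",
"\n",
"# Both splits keep the 60/30/10 ratio.\n",
"print(toy_train['cancer_type'].value_counts())\n",
"print(toy_val['cancer_type'].value_counts())"
]
},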
{
"cell_type": "markdown",
"id": "4903244b-548f-4672-967d-1c62825b6fce",
"metadata": {},
"source": [
"## Building a custom PyTorch dataset"
]
},
{
"cell_type": "markdown",
"id": "7e333251-c4e7-41f0-a086-12a3d95b723f",
"metadata": {},
"source": [
"### Wrapping the collected expression profiles in a `Dataset`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e2f78725-cda6-4e8d-9029-a4a31f6f9ab7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"class GenomeDataset(Dataset):\n",
"    def __init__(self, dataframe, label_encoder=None):\n",
"        self.dataframe = dataframe\n",
"\n",
"        # Convert the expression vectors to one float32 tensor. Stacking them into a\n",
"        # single 2-D NumPy array first avoids the slow list-of-ndarrays conversion.\n",
"        features = np.stack(dataframe['genome_frequencies'].tolist()).astype(np.float32)\n",
"        self.genome_frequencies = torch.from_numpy(features)\n",
"\n",
"        # Convert the cancer subtypes to integer labels. Passing the training set's\n",
"        # encoder to the other splits guarantees one shared label mapping.\n",
"        if label_encoder is None:\n",
"            label_encoder = LabelEncoder().fit(dataframe['cancer_type'])\n",
"        self.label_encoder = label_encoder\n",
"        self.cancer_types = torch.tensor(self.label_encoder.transform(dataframe['cancer_type']), dtype=torch.long)\n",
"\n",
"    def __getitem__(self, index):\n",
"        # Return a tuple of (expression profile, encoded cancer subtype)\n",
"        return self.genome_frequencies[index], self.cancer_types[index]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.dataframe)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "aaa2c50c-c79e-4bca-812f-1a06c9f485d5",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_19797/2483914749.py:11: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:245.)\n",
" self.genome_frequencies = torch.tensor(dataframe['genome_frequencies'].tolist(), dtype=torch.float32)\n"
]
}
],
"source": [
"# Example usage\n",
"# train_df and val_df are the training and validation splits created above\n",
"train_dataset = GenomeDataset(train_df)\n",
"valid_dataset = GenomeDataset(val_df)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a7fb59af-bd06-42d4-acce-03266a85bf36",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Genome frequency from dataframe:\n",
"[2.50645e+01 1.12500e-01 5.63997e+01 ... 0.00000e+00 1.29000e-02\n",
" 2.47100e-01]\n",
"\n",
"Cancer type from dataframe: kirc\n",
"\n",
"Genome frequency from dataset:\n",
"tensor([2.5065e+01, 1.1250e-01, 5.6400e+01, ..., 0.0000e+00, 1.2900e-02,\n",
" 2.4710e-01])\n",
"\n",
"Cancer type from dataset: 1\n"
]
}
],
"source": [
"# Inspect the first item from the training dataframe\n",
"train_df_head = train_df.head(n=1)\n",
"train_df_genome_frequence = train_df_head.iloc[0][\"genome_frequencies\"]\n",
"train_df_cancer_type = train_df_head.iloc[0][\"cancer_type\"]\n",
"print(f\"Genome frequency from dataframe:{os.linesep}{train_df_genome_frequence}{os.linesep}\")\n",
"print(f\"Cancer type from dataframe: {train_df_cancer_type}{os.linesep}\")\n",
"\n",
"# Inspect the first item from the training dataset\n",
"datapoint_features, datapoint_label = train_dataset[0]\n",
"print(f\"Genome frequency from dataset:{os.linesep}{datapoint_features}{os.linesep}\")\n",
"print(f\"Cancer type from dataset: {datapoint_label}\")"
]
},
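{
"cell_type": "markdown",
"id": "log-scaling-sketch-md",
"metadata": {},
"source": [
"Raw TPM values are non-negative and heavily right-skewed, which can make gradient-based training harder. A common preprocessing step for expression data, not applied in this notebook, is a `log1p` transform followed by per-gene standardization. The scaling statistics are estimated on the training split only, so that nothing about the validation split leaks into training. Below is a minimal sketch on random gamma-distributed stand-in data; the shapes and the distribution are arbitrary assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "log-scaling-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# Random stand-ins for the TPM matrices (samples x genes); the real data has 60660 genes.\n",
"X_train_raw = rng.gamma(shape=0.5, scale=50.0, size=(80, 10))\n",
"X_val_raw = rng.gamma(shape=0.5, scale=50.0, size=(20, 10))\n",
"\n",
"# log1p compresses the long right tail and keeps zeros at zero.\n",
"X_train_log = np.log1p(X_train_raw)\n",
"X_val_log = np.log1p(X_val_raw)\n",
"\n",
"# Estimate mean and standard deviation on the training split only, then apply to both splits.\n",
"scaler = StandardScaler().fit(X_train_log)\n",
"X_train = scaler.transform(X_train_log)\n",
"X_val = scaler.transform(X_val_log)\n",
"\n",
"print(X_train.mean(axis=0).round(6))  # close to 0 for every gene\n",
"print(X_train.std(axis=0).round(6))   # close to 1 for every gene"
]
},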
{
"cell_type": "markdown",
"id": "418bc6a0-2ddb-4596-87d1-3e670195297c",
"metadata": {
"tags": []
},
"source": [
"## Principal component analysis (PCA)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e6672e50-47e6-48fc-9e1e-cac0f0a606f1",
"metadata": {},
"outputs": [],
"source": [
"# import numpy as np\n",
"# from sklearn.decomposition import PCA\n",
"# from sklearn.preprocessing import StandardScaler\n",
"\n",
"# # X is the dataset: the list of expression vectors (rick) from the loading cell\n",
"# X = rick\n",
"\n",
"# # Standardize the data\n",
"# scaler = StandardScaler()\n",
"# X_scaled = scaler.fit_transform(X)\n",
"\n",
"# # Create the PCA object, keeping 150 principal components\n",
"# pca = PCA(n_components=150)\n",
"\n",
"# # Perform the PCA\n",
"# X_pca = pca.fit_transform(X_scaled)\n",
"\n",
"# # The resulting principal components\n",
"# print(\"Transformed data:\", X_pca)\n",
"\n",
"# # Explained variance ratio of each component\n",
"# print(\"Variance explained by each component:\", pca.explained_variance_ratio_)\n"
]
},
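{
"cell_type": "markdown",
"id": "pca-pipeline-sketch-md",
"metadata": {},
"source": [
"The commented-out PCA above is fitted on all samples in `rick`, including those that later end up in the validation split, and keeps a fixed 150 components. A leakage-free variant chains scaling and PCA in a scikit-learn `Pipeline` that is fitted on the training split only. It is sketched below on random stand-in matrices that are assumed to be already split. Passing a float such as `0.95` as `n_components` keeps as many components as are needed to explain that fraction of the variance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pca-pipeline-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# Random stand-ins for the (already split) expression matrices: samples x genes.\n",
"X_train = rng.normal(size=(80, 200))\n",
"X_val = rng.normal(size=(20, 200))\n",
"\n",
"# Scaling and PCA are fitted on the training split only, then applied to both splits.\n",
"pca_pipeline = make_pipeline(StandardScaler(), PCA(n_components=0.95))\n",
"Z_train = pca_pipeline.fit_transform(X_train)\n",
"Z_val = pca_pipeline.transform(X_val)\n",
"\n",
"pca = pca_pipeline.named_steps['pca']\n",
"print('Components kept:', pca.n_components_)\n",
"print('Explained variance:', pca.explained_variance_ratio_.sum())"
]
},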
{
"cell_type": "markdown",
"id": "9199fdeb-0d48-44c2-8bec-db2a7d7cbd4d",
"metadata": {},
"source": [
"# Neural network definition"
]
},
{
"cell_type": "markdown",
"id": "e53132b9-6222-4739-be49-7628e5a37709",
"metadata": {},
"source": [
"### Simple neural network"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "76b8eec8-d24b-4696-82bf-ebb286e7d1e7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Model definition\n",
"class SimpleNN(nn.Module):\n",
"    def __init__(self, input_size, hidden_size, num_classes):\n",
"        super(SimpleNN, self).__init__()\n",
"        self.fc1 = nn.Linear(input_size, hidden_size)\n",
"        self.relu = nn.ReLU()\n",
"        self.fc2 = nn.Linear(hidden_size, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        # Returns raw logits, as expected by nn.CrossEntropyLoss\n",
"        out = self.fc1(x)\n",
"        out = self.relu(out)\n",
"        out = self.fc2(out)\n",
"        return out"
]
},
{
"cell_type": "markdown",
"id": "e2e9e0dd-3d4f-4999-9e65-704266d5e4a2",
"metadata": {
"tags": []
},
"source": [
"### More complex neural network"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "944d463e-12ed-4447-8587-ee9c60ce3eb6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class ComplexNN(nn.Module):\n",
"    def __init__(self, input_size, hidden_size, num_classes):\n",
"        # hidden_size is accepted for interface compatibility with SimpleNN,\n",
"        # but the layer widths below are fixed.\n",
"        super(ComplexNN, self).__init__()\n",
"        # Define the layers\n",
"        self.fc1 = nn.Linear(input_size, 1024)  # first hidden layer\n",
"        self.fc2 = nn.Linear(1024, 512)  # second hidden layer\n",
"        self.fc3 = nn.Linear(512, 256)  # third hidden layer\n",
"        self.fc4 = nn.Linear(256, num_classes)  # output layer\n",
"        self.dropout = nn.Dropout(p=0.5)  # dropout\n",
"\n",
"    def forward(self, x):\n",
"        # Forward pass\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.dropout(x)\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.dropout(x)\n",
"        x = F.relu(self.fc3(x))\n",
"        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,\n",
"        # so a sigmoid (or log_softmax) here would distort the loss.\n",
"        return self.fc4(x)"
]
},
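{
"cell_type": "markdown",
"id": "sigmoid-ce-sketch-md",
"metadata": {},
"source": [
"Feeding sigmoid outputs into `nn.CrossEntropyLoss` is a pitfall worth spelling out. The loss applies log-softmax to its inputs itself, so it expects unbounded logits. If every logit is squashed into $(0, 1)$, even the most confident possible prediction for three classes (1 for the true class, 0 for the other two) still has a loss of $-\\ln\\frac{e}{e + 2} = \\ln(1 + 2/e) \\approx 0.551$. The saturated sigmoid also shrinks the gradients. The sketch below checks this bound numerically."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sigmoid-ce-sketch",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"target = torch.tensor([0])  # the true class is class 0\n",
"\n",
"# Best case after a sigmoid: true class at 1.0, the two other classes at 0.0.\n",
"squashed_logits = torch.tensor([[1.0, 0.0, 0.0]])\n",
"print(criterion(squashed_logits, target).item())\n",
"print(math.log(1 + 2 / math.e))  # the analytic lower bound ln(1 + 2/e)\n",
"\n",
"# Unbounded logits can push the loss towards 0.\n",
"raw_logits = torch.tensor([[10.0, 0.0, 0.0]])\n",
"print(criterion(raw_logits, target).item())"
]
},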
{
"cell_type": "code",
"execution_count": 14,
"id": "60789428-7d6e-4737-a83a-1138f6a650f7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# input_size is the length of each expression vector and num_classes is the number of cancer subtypes\n",
"#model = SimpleNN(input_size=60660, hidden_size=5000, num_classes=3)\n",
"model = ComplexNN(input_size=60660, hidden_size=5000, num_classes=3)\n",
"\n",
"# Data loaders\n",
"train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)\n",
"valid_loader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=False)"
]
},
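{
"cell_type": "markdown",
"id": "param-count-sketch-md",
"metadata": {},
"source": [
"For scale, the arithmetic below counts the weights and biases of each `nn.Linear` layer in `ComplexNN` for `input_size=60660`. The helper `linear_params` is a hypothetical name. Nearly all of the roughly 62.8 million parameters sit in the first layer, while the training split holds 827 samples, so stronger regularization or a dimensionality reduction such as the PCA above may be worth trying. For an instantiated model, `sum(p.numel() for p in model.parameters())` should report the same total."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "param-count-sketch",
"metadata": {},
"outputs": [],
"source": [
"def linear_params(n_in, n_out):\n",
"    # An nn.Linear(n_in, n_out) layer has an n_out x n_in weight matrix plus n_out biases.\n",
"    return n_in * n_out + n_out\n",
"\n",
"# Layer widths of ComplexNN with input_size=60660 and num_classes=3\n",
"layer_sizes = [60660, 1024, 512, 256, 3]\n",
"per_layer = [linear_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]\n",
"print('Parameters per layer:', per_layer)\n",
"print('Total parameters:', sum(per_layer))"
]
},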
{
"cell_type": "code",
"execution_count": 15,
"id": "de6e81de-0096-443a-a0b6-90cddecf5f88",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"num_epochs = 250"
]
},
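{
"cell_type": "markdown",
"id": "evaluate-sketch-md",
"metadata": {},
"source": [
"Besides the two losses, validation accuracy is an easy metric to track. Below is a minimal sketch of an evaluation helper; the name `evaluate` is hypothetical. It assumes a model that returns raw logits and a loader that yields `(features, labels)` batches, as `valid_loader` does. The demo runs it on random tensors and a single `nn.Linear` layer instead of the real data and model. Because the helper switches the model to evaluation mode (which disables dropout), call `model.train()` again before the next training epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "evaluate-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"def evaluate(model, loader, criterion):\n",
"    # Mean loss and accuracy over all batches of a loader.\n",
"    model.eval()  # evaluation mode: disables dropout\n",
"    total_loss, correct, n = 0.0, 0, 0\n",
"    with torch.no_grad():\n",
"        for features, labels in loader:\n",
"            logits = model(features)\n",
"            # criterion returns the batch mean, so weight it by the batch size\n",
"            total_loss += criterion(logits, labels).item() * labels.size(0)\n",
"            correct += (logits.argmax(dim=1) == labels).sum().item()\n",
"            n += labels.size(0)\n",
"    return total_loss / n, correct / n\n",
"\n",
"# Demo on random stand-in data: 50 samples, 10 features, 3 classes.\n",
"torch.manual_seed(0)\n",
"demo_loader = DataLoader(TensorDataset(torch.randn(50, 10), torch.randint(0, 3, (50,))), batch_size=16)\n",
"demo_model = nn.Linear(10, 3)\n",
"val_loss, val_acc = evaluate(demo_model, demo_loader, nn.CrossEntropyLoss())\n",
"print(f'validation loss {val_loss:.4f}, accuracy {val_acc:.2%}')"
]
},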
{
"cell_type": "code",
"execution_count": 16,
"id": "a5deb2ed-c685-4d80-bc98-d6dd27334d82",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/250], Trainingsverlust: 1.0121, Validierungsverlust: 0.9179\n",
"Epoch [2/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [3/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [4/250], Trainingsverlust: 0.9544, Validierungsverlust: 0.9179\n",
"Epoch [5/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [6/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [7/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [8/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [9/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [10/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [11/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [12/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [13/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [14/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [15/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [16/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [17/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [18/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [19/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [20/250], Trainingsverlust: 0.9544, Validierungsverlust: 0.9179\n",
"Epoch [21/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [22/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [23/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [24/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [25/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [26/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [27/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [28/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [29/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [30/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [31/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [32/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [33/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [34/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [35/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [36/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [37/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [38/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [39/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [40/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [41/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [42/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [43/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [44/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [45/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [46/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [47/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [48/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [49/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [50/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [51/250], Trainingsverlust: 0.9561, Validierungsverlust: 0.9179\n",
"Epoch [52/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [53/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [54/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [55/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [56/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [57/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [58/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [59/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [60/250], Trainingsverlust: 0.9560, Validierungsverlust: 0.9179\n",
"Epoch [61/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [62/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [63/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [64/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [65/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [66/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [67/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [68/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [69/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [70/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [71/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [72/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [73/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [74/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [75/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [76/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [77/250], Trainingsverlust: 0.9561, Validierungsverlust: 0.9179\n",
"Epoch [78/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [79/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [80/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [81/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [82/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [83/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [84/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [85/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [86/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [87/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [88/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [89/250], Trainingsverlust: 0.9560, Validierungsverlust: 0.9179\n",
"Epoch [90/250], Trainingsverlust: 0.9564, Validierungsverlust: 0.9179\n",
"Epoch [91/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [92/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [93/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [94/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [95/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [96/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [97/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [98/250], Trainingsverlust: 0.9544, Validierungsverlust: 0.9179\n",
"Epoch [99/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [100/250], Trainingsverlust: 0.9543, Validierungsverlust: 0.9179\n",
"Epoch [101/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [102/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [103/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [104/250], Trainingsverlust: 0.9544, Validierungsverlust: 0.9179\n",
"Epoch [105/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [106/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [107/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [108/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [109/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [110/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [111/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [112/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [113/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [114/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [115/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [116/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [117/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [118/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [119/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [120/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [121/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [122/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [123/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [124/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [125/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [126/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [127/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [128/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [129/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [130/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [131/250], Trainingsverlust: 0.9545, Validierungsverlust: 0.9179\n",
"Epoch [132/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [133/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [134/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [135/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [136/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [137/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [138/250], Trainingsverlust: 0.9561, Validierungsverlust: 0.9179\n",
"Epoch [139/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [140/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [141/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [142/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [143/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [144/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [145/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [146/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [147/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [148/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [149/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [150/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [151/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [152/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [153/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [154/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [155/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [156/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [157/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [158/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [159/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [160/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [161/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [162/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [163/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [164/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [165/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [166/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [167/250], Trainingsverlust: 0.9558, Validierungsverlust: 0.9179\n",
"Epoch [168/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [169/250], Trainingsverlust: 0.9561, Validierungsverlust: 0.9179\n",
"Epoch [170/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [171/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [172/250], Trainingsverlust: 0.9544, Validierungsverlust: 0.9179\n",
"Epoch [173/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [174/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [175/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [176/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [177/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [178/250], Trainingsverlust: 0.9561, Validierungsverlust: 0.9179\n",
"Epoch [179/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [180/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [181/250], Trainingsverlust: 0.9557, Validierungsverlust: 0.9179\n",
"Epoch [182/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [183/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [184/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [185/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [186/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [187/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [188/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [189/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [190/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [191/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [192/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [193/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [194/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [195/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [196/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [197/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [198/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [199/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [200/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [201/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [202/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [203/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [204/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [205/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [206/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [207/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [208/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [209/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [210/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [211/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [212/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [213/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [214/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [215/250], Trainingsverlust: 0.9548, Validierungsverlust: 0.9179\n",
"Epoch [216/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [217/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [218/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [219/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [220/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [221/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [222/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [223/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [224/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [225/250], Trainingsverlust: 0.9543, Validierungsverlust: 0.9179\n",
"Epoch [226/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [227/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [228/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [229/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n",
"Epoch [230/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [231/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [232/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [233/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [234/250], Trainingsverlust: 0.9554, Validierungsverlust: 0.9179\n",
"Epoch [235/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [236/250], Trainingsverlust: 0.9552, Validierungsverlust: 0.9179\n",
"Epoch [237/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [238/250], Trainingsverlust: 0.9559, Validierungsverlust: 0.9179\n",
"Epoch [239/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [240/250], Trainingsverlust: 0.9545, Validierungsverlust: 0.9179\n",
"Epoch [241/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [242/250], Trainingsverlust: 0.9550, Validierungsverlust: 0.9179\n",
"Epoch [243/250], Trainingsverlust: 0.9546, Validierungsverlust: 0.9179\n",
"Epoch [244/250], Trainingsverlust: 0.9556, Validierungsverlust: 0.9179\n",
"Epoch [245/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [246/250], Trainingsverlust: 0.9553, Validierungsverlust: 0.9179\n",
"Epoch [247/250], Trainingsverlust: 0.9549, Validierungsverlust: 0.9179\n",
"Epoch [248/250], Trainingsverlust: 0.9551, Validierungsverlust: 0.9179\n",
"Epoch [249/250], Trainingsverlust: 0.9555, Validierungsverlust: 0.9179\n",
"Epoch [250/250], Trainingsverlust: 0.9547, Validierungsverlust: 0.9179\n"
]
}
],
"source": [
    "# Lists for storing the per-epoch losses\n",
    "train_losses = []\n",
    "valid_losses = []\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Training phase\n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for inputs, labels in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss += loss.item()\n",
    "\n",
    "    # Average training loss over all batches\n",
    "    train_loss /= len(train_loader)\n",
    "    train_losses.append(train_loss)\n",
    "\n",
    "    # Validation phase (no gradient tracking needed)\n",
    "    model.eval()\n",
    "    valid_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in valid_loader:\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            valid_loss += loss.item()\n",
    "\n",
    "    # Average validation loss over all batches\n",
    "    valid_loss /= len(valid_loader)\n",
    "    valid_losses.append(valid_loss)\n",
    "\n",
    "    print(f'Epoch [{epoch+1}/{num_epochs}], Trainingsverlust: {train_loss:.4f}, Validierungsverlust: {valid_loss:.4f}')"
]
},
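  {
   "cell_type": "markdown",
   "id": "3f6c1a2e-5b7d-4e21-9c0a-7d4b8e2f1a01",
   "metadata": {},
   "source": [
    "In the training log of the previous cell, the validation loss stays at 0.9179 in every epoch from 135 to 250, while the training loss only fluctuates around 0.955. Once the loss has plateaued like this, further epochs mainly cost time.\n",
    "\n",
    "The next cell is a minimal early-stopping sketch and is not part of the original training loop. The helper `plateau_epoch` and its parameters `patience` and `min_delta` are illustrative assumptions. It scans the recorded `valid_losses` and reports the first epoch at which the loss had not improved by more than `min_delta` for `patience` consecutive epochs. The same check could be moved inside the training loop to `break` early."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6c1a2e-5b7d-4e21-9c0a-7d4b8e2f1a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical early-stopping check (illustration only, not part of the original pipeline).\n",
    "# Returns the 1-based epoch at which training would have stopped, or None if the\n",
    "# loss never went `patience` consecutive epochs without improving by more than min_delta.\n",
    "def plateau_epoch(losses, patience=10, min_delta=1e-4):\n",
    "    best = float('inf')\n",
    "    epochs_without_improvement = 0\n",
    "    for epoch, loss in enumerate(losses, start=1):\n",
    "        if loss < best - min_delta:\n",
    "            # Meaningful improvement: remember it and reset the counter\n",
    "            best = loss\n",
    "            epochs_without_improvement = 0\n",
    "        else:\n",
    "            epochs_without_improvement += 1\n",
    "            if epochs_without_improvement >= patience:\n",
    "                return epoch\n",
    "    return None\n",
    "\n",
    "\n",
    "stop_epoch = plateau_epoch(valid_losses)\n",
    "print('Early stopping would have triggered at epoch:', stop_epoch)"
   ]
  },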
{
"cell_type": "code",
"execution_count": 17,
"id": "baf1caa8-d3d9-48e8-9339-81194521528d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
    "plt.plot(train_losses, label='Training loss')\n",
    "plt.plot(valid_losses, label='Validation loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('Training and validation loss over time')\n",
"plt.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "8e339354-a7cc-4e8a-9323-4be41ef62117",
"metadata": {},
"outputs": [],
"source": [
    "import pickle\n",
    "\n",
    "# Load the pickled 'rick' dataset from disk\n",
    "with open('rick.pickle', 'rb') as f:\n",
    "    rick = pickle.load(f)\n"
]
},
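  {
   "cell_type": "markdown",
   "id": "7a2d9e41-0c3b-4f58-8e16-2b5f4c9d3e01",
   "metadata": {},
   "source": [
    "The PCA cell that follows stops with `ValueError: setting an array element with a sequence.`, so it is worth checking what the pickle actually contains before passing it to scikit-learn, which expects a 2-D array of numbers. The next cell is an illustrative diagnostic, not part of the original analysis. It only assumes that `rick` may be a pandas `Series` or `DataFrame` and prints its type, shape, dtypes and the type of its first entry. A `list` or `ndarray` as the first entry would mean that the cells hold whole sequences instead of scalars."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2d9e41-0c3b-4f58-8e16-2b5f4c9d3e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative diagnostic: inspect the unpickled object before using it with scikit-learn\n",
    "print(type(rick))\n",
    "if hasattr(rick, 'shape'):\n",
    "    print('Shape:', rick.shape)\n",
    "if hasattr(rick, 'dtypes'):\n",
    "    print(rick.dtypes)\n",
    "if hasattr(rick, 'iloc') and rick.size > 0:\n",
    "    # Series: first element; DataFrame: top-left cell\n",
    "    first = rick.iloc[0] if rick.ndim == 1 else rick.iloc[0, 0]\n",
    "    print('First entry type:', type(first))"
   ]
  },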
{
"cell_type": "markdown",
"id": "be10a487-728e-4953-a081-9103d485378c",
"metadata": {},
"source": [
    "## Principal Component Analysis (PCA)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "088db0b3-8c33-41ff-a543-1b1e50c5e589",
"metadata": {
"tags": []
},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "setting an array element with a sequence.",
     "output_type": "error",
     "traceback": [
      "---------------------------------------------------------------------------",
      "TypeError                                 Traceback (most recent call last)",
      "TypeError: only size-1 arrays can be converted to Python scalars",
      "\nThe above exception was the direct cause of the following exception:\n",
      "ValueError                                Traceback (most recent call last)",
      "/tmp/ipykernel_19797/2932508590.py in ?()\n---> 11 import numpy as np",
      "~/.local/lib/python3.8/site-packages/sklearn/utils/_set_output.py in ?(self, X, *args, **kwargs)\n--> 157 data_to_wrap = f(self, X, *args, **kwargs)",
      "~/.local/lib/python3.8/site-packages/sklearn/base.py in ?(self, X, y, **fit_params)\n--> 916 return self.fit(X, **fit_params).transform(X)",
      "~/.local/lib/python3.8/site-packages/sklearn/preprocessing/_data.py in ?(self, X, y, sample_weight)\n--> 839 return self.partial_fit(X, y, sample_weight)",
      "~/.local/lib/python3.8/site-packages/sklearn/base.py in ?(estimator, *args, **kwargs)\n-> 1152 return fit_method(estimator, *args, **kwargs)",
      "~/.local/lib/python3.8/site-packages/sklearn/preprocessing/_data.py in ?(self, X, y, sample_weight)\n--> 875 X = self._validate_data(",
      "~/.local/lib/python3.8/site-packages/sklearn/base.py in ?(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\n--> 605 out = check_array(X, input_name=\"X\", **check_params)",
      "~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in ?(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\n--> 917 raise ValueError(",
      "~/.local/lib/python3.8/site-packages/sklearn/utils/_array_api.py in ?(array, dtype, order, copy, xp)\n--> 380 array = numpy.asarray(array, order=order, dtype=dtype)",
      "~/miniconda3/envs/rl/lib/python3.8/site-packages/pandas/core/generic.py in ?(self, dtype)\n-> 1998 arr = np.asarray(values, dtype=dtype)",
      "ValueError: setting an array element with a sequence."
     ]
    }
   ],
"source": [
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# X is the dataset to analyze (samples x features);\n",
    "# here it is the unpickled 'rick' object\n",
    "X = rick\n",
    "\n",
    "# Standardize each feature to zero mean and unit variance\n",
    "scaler = StandardScaler()\n",
    "X_scaled = scaler.fit_transform(X)\n",
    "\n",
    "# Create the PCA object (assumption: keep 150 principal components)\n",
    "pca = PCA(n_components=150)\n",
    "\n",
    "# Fit the PCA and project the data onto the principal components\n",
    "X_pca = pca.fit_transform(X_scaled)\n",
    "\n",
    "# Shape of the transformed data (samples x components)\n",
    "print(\"Transformed data shape:\", X_pca.shape)\n",
    "\n",
    "# Fraction of the total variance explained by each component\n",
    "print(\"Explained variance ratio per component:\", pca.explained_variance_ratio_)\n"
]
},
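  {
   "cell_type": "markdown",
   "id": "c4e8b2d7-9f1a-4a63-b5d0-6e3f7a1c8b01",
   "metadata": {},
   "source": [
    "The traceback of the PCA cell ends in `pandas/core/generic.py` (`__array__`) with `ValueError: setting an array element with a sequence.`, preceded by `TypeError: only size-1 arrays can be converted to Python scalars`. This suggests that `rick` is a pandas object whose cells contain sequences (for example, one expression vector per sample) rather than scalars, so `StandardScaler` cannot turn it into a 2-D float matrix.\n",
    "\n",
    "The next cell is a minimal sketch of a possible fix, not the original author's method. It assumes that every entry of `rick` (or every cell of each row, if `rick` is a `DataFrame`) is numeric, that label or ID columns have been dropped, and that all samples yield vectors of the same length. The hypothetical helper `to_feature_matrix` flattens each sample into one row and stacks the rows with `np.vstack`. Because `PCA` raises an error when `n_components` exceeds `min(n_samples, n_features)`, the sketch also caps the requested 150 components at that bound and prints the cumulative explained variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e8b2d7-9f1a-4a63-b5d0-6e3f7a1c8b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "def to_feature_matrix(data):\n",
    "    # Hypothetical helper: turn a Series/list of per-sample sequences, or a DataFrame\n",
    "    # whose cells hold numeric scalars or sequences, into a 2-D float array.\n",
    "    if isinstance(data, pd.DataFrame):\n",
    "        samples = [\n",
    "            np.concatenate([np.ravel(np.asarray(cell, dtype=float)) for cell in row])\n",
    "            for row in data.itertuples(index=False)\n",
    "        ]\n",
    "    else:\n",
    "        samples = [np.ravel(np.asarray(sample, dtype=float)) for sample in data]\n",
    "    # np.vstack raises a ValueError if the samples have different lengths\n",
    "    return np.vstack(samples)\n",
    "\n",
    "\n",
    "X = to_feature_matrix(rick)\n",
    "print('Feature matrix shape:', X.shape)\n",
    "\n",
    "X_scaled = StandardScaler().fit_transform(X)\n",
    "\n",
    "# PCA cannot extract more components than min(n_samples, n_features)\n",
    "n_components = min(150, *X.shape)\n",
    "pca = PCA(n_components=n_components)\n",
    "X_pca = pca.fit_transform(X_scaled)\n",
    "\n",
    "cumulative = np.cumsum(pca.explained_variance_ratio_)\n",
    "print('Cumulative explained variance of the first 10 components:', cumulative[:10])"
   ]
  },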
{
"cell_type": "code",
"execution_count": null,
"id": "b11bbe20-0494-4e7a-83ff-3cb0bfa82f3b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}