src.train
Training module for the Hepatitis C classification model.
This module contains the ModelTrainer class and training utilities, separated from model definitions for better code organization.
"""
Training module for Hepatitis C classification model.

This module contains the ModelTrainer class and training utilities,
separated from model definitions for better code organization.
"""

from __future__ import annotations
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time


class ModelTrainer:
    """
    Class to handle training and validation of neural network models.

    This trainer can be reused with different models and datasets,
    following the separation of concerns principle.

    Parameters
    ----------
    model : nn.Module
        The neural network model to be trained. It is moved to ``device``
        on construction.
    device : str
        Device to run the training on ('cpu' or 'cuda').

    Attributes
    ----------
    model : nn.Module
        The neural network model to be trained.
    device : str
        Device to run the training on ('cpu' or 'cuda').
    history : dict
        Per-epoch metrics appended by ``train``, stored as lists under the
        keys 'train_loss', 'train_acc', 'val_loss' and 'val_acc'.
    best_model_state : dict
        Independent copy of the model's ``state_dict`` from the epoch with
        the highest validation accuracy; set by ``train``.

    Examples
    --------
    >>> from src.models import HepatitisNet
    >>> from src.train import ModelTrainer
    >>> model = HepatitisNet(input_size=12)
    >>> trainer = ModelTrainer(model, device='cuda')
    >>> history = trainer.train(train_loader, val_loader, epochs=50)
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    def train_epoch(self, train_loader: DataLoader, criterion: nn.Module,
                    optimizer: optim.Optimizer) -> tuple[float, float]:
        """
        Train the model for one epoch.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data; each batch is ``(data, target)``.
        criterion : nn.Module
            Loss function. Assumed to return the batch *mean* (the default
            ``reduction='mean'``) so it can be re-weighted by batch size.
        optimizer : optim.Optimizer
            Optimizer for updating weights.

        Returns
        -------
        tuple[float, float]
            Per-sample average training loss and accuracy (in percent).

        Raises
        ------
        ValueError
            If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)

            optimizer.zero_grad()
            output = self.model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not skew the epoch average.
            total_loss += loss.item() * batch_size
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        return total_loss / total, 100. * correct / total

    def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
        """
        Validate the model for one epoch (no gradient tracking, eval mode).

        Parameters
        ----------
        val_loader : DataLoader
            DataLoader for validation data; each batch is ``(data, target)``.
        criterion : nn.Module
            Loss function. Assumed to return the batch *mean* (the default
            ``reduction='mean'``) so it can be re-weighted by batch size.

        Returns
        -------
        tuple[float, float]
            Per-sample average validation loss and accuracy (in percent).

        Raises
        ------
        ValueError
            If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = criterion(output, target)

                batch_size = target.size(0)
                # Weight each batch mean by its size so a short final batch
                # does not skew the epoch average.
                total_loss += loss.item() * batch_size
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += batch_size

        if total == 0:
            raise ValueError("val_loader yielded no samples")
        return total_loss / total, 100. * correct / total

    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50,
              learning_rate: float = 0.001, patience: int = 10) -> dict:
        """
        Train the model for multiple epochs with early stopping.

        When the loop ends, the model is reloaded with the weights from the
        epoch with the highest validation accuracy.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data.
        val_loader : DataLoader
            DataLoader for validation data.
        epochs : int
            Maximum number of epochs to train.
        learning_rate : float
            Initial learning rate for the Adam optimizer.
        patience : int
            Stop after this many consecutive epochs without an improvement in
            validation accuracy. Defaults to 10.

        Returns
        -------
        dict
            ``self.history``: per-epoch lists under 'train_loss', 'train_acc',
            'val_loss' and 'val_acc'. The lists keep growing across calls.

        Examples
        --------
        >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
        >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        # Halve the learning rate when validation loss stalls for 5 epochs.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # Start below any achievable accuracy so the first epoch is always
        # recorded, and snapshot the initial weights so there is always a
        # state to restore (e.g. when epochs == 0).
        best_val_acc = float('-inf')
        patience_counter = 0
        self.best_model_state = copy.deepcopy(self.model.state_dict())

        print(f"Training on {self.device}")
        print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
        print("-" * 50)

        start_time = time.time()

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
            val_loss, val_acc = self.validate_epoch(val_loader, criterion)
            scheduler.step(val_loss)

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            if epoch % 10 == 0:
                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                # state_dict() returns references to the live parameter
                # tensors, which later optimizer steps update in place; a deep
                # copy is required to actually freeze this epoch's weights.
                self.best_model_state = copy.deepcopy(self.model.state_dict())
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(self.best_model_state)
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {best_val_acc:.2f}%")

        return self.history
class ModelTrainer:
    """
    Class to handle training and validation of neural network models.

    This trainer can be reused with different models and datasets,
    following the separation of concerns principle.

    Parameters
    ----------
    model : nn.Module
        The neural network model to be trained. It is moved to ``device``
        on construction.
    device : str
        Device to run the training on ('cpu' or 'cuda').

    Attributes
    ----------
    model : nn.Module
        The neural network model to be trained.
    device : str
        Device to run the training on ('cpu' or 'cuda').
    history : dict
        Per-epoch metrics appended by ``train``, stored as lists under the
        keys 'train_loss', 'train_acc', 'val_loss' and 'val_acc'.
    best_model_state : dict
        Independent copy of the model's ``state_dict`` from the epoch with
        the highest validation accuracy; set by ``train``.

    Examples
    --------
    >>> from src.models import HepatitisNet
    >>> from src.train import ModelTrainer
    >>> model = HepatitisNet(input_size=12)
    >>> trainer = ModelTrainer(model, device='cuda')
    >>> history = trainer.train(train_loader, val_loader, epochs=50)
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    def train_epoch(self, train_loader: DataLoader, criterion: nn.Module,
                    optimizer: optim.Optimizer) -> tuple[float, float]:
        """
        Train the model for one epoch.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data; each batch is ``(data, target)``.
        criterion : nn.Module
            Loss function. Assumed to return the batch *mean* (the default
            ``reduction='mean'``) so it can be re-weighted by batch size.
        optimizer : optim.Optimizer
            Optimizer for updating weights.

        Returns
        -------
        tuple[float, float]
            Per-sample average training loss and accuracy (in percent).

        Raises
        ------
        ValueError
            If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)

            optimizer.zero_grad()
            output = self.model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not skew the epoch average.
            total_loss += loss.item() * batch_size
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        return total_loss / total, 100. * correct / total

    def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
        """
        Validate the model for one epoch (no gradient tracking, eval mode).

        Parameters
        ----------
        val_loader : DataLoader
            DataLoader for validation data; each batch is ``(data, target)``.
        criterion : nn.Module
            Loss function. Assumed to return the batch *mean* (the default
            ``reduction='mean'``) so it can be re-weighted by batch size.

        Returns
        -------
        tuple[float, float]
            Per-sample average validation loss and accuracy (in percent).

        Raises
        ------
        ValueError
            If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = criterion(output, target)

                batch_size = target.size(0)
                # Weight each batch mean by its size so a short final batch
                # does not skew the epoch average.
                total_loss += loss.item() * batch_size
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += batch_size

        if total == 0:
            raise ValueError("val_loader yielded no samples")
        return total_loss / total, 100. * correct / total

    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50,
              learning_rate: float = 0.001, patience: int = 10) -> dict:
        """
        Train the model for multiple epochs with early stopping.

        When the loop ends, the model is reloaded with the weights from the
        epoch with the highest validation accuracy.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data.
        val_loader : DataLoader
            DataLoader for validation data.
        epochs : int
            Maximum number of epochs to train.
        learning_rate : float
            Initial learning rate for the Adam optimizer.
        patience : int
            Stop after this many consecutive epochs without an improvement in
            validation accuracy. Defaults to 10.

        Returns
        -------
        dict
            ``self.history``: per-epoch lists under 'train_loss', 'train_acc',
            'val_loss' and 'val_acc'. The lists keep growing across calls.

        Examples
        --------
        >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
        >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
        """
        import copy

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        # Halve the learning rate when validation loss stalls for 5 epochs.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # Start below any achievable accuracy so the first epoch is always
        # recorded, and snapshot the initial weights so there is always a
        # state to restore (e.g. when epochs == 0).
        best_val_acc = float('-inf')
        patience_counter = 0
        self.best_model_state = copy.deepcopy(self.model.state_dict())

        print(f"Training on {self.device}")
        print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
        print("-" * 50)

        start_time = time.time()

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
            val_loss, val_acc = self.validate_epoch(val_loader, criterion)
            scheduler.step(val_loss)

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            if epoch % 10 == 0:
                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                # state_dict() returns references to the live parameter
                # tensors, which later optimizer steps update in place; a deep
                # copy is required to actually freeze this epoch's weights.
                self.best_model_state = copy.deepcopy(self.model.state_dict())
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(self.best_model_state)
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {best_val_acc:.2f}%")

        return self.history
Class to handle training and validation of neural network models.
This trainer can be reused with different models and datasets, following the separation of concerns principle.
Parameters
- model (nn.Module): The neural network model to be trained.
- device (str): Device to run the training on ('cpu' or 'cuda').
Attributes
- model (nn.Module): The neural network model to be trained.
- device (str): Device to run the training on ('cpu' or 'cuda').
- history (dict): Training history with one list per metric under the keys 'train_loss', 'train_acc', 'val_loss' and 'val_acc'.
Examples
>>> from src.models import HepatitisNet
>>> from src.train import ModelTrainer
>>> model = HepatitisNet(input_size=12)
>>> trainer = ModelTrainer(model, device='cuda')
>>> history = trainer.train(train_loader, val_loader, epochs=50)
def train_epoch(self, train_loader: DataLoader, criterion: nn.Module,
                optimizer: optim.Optimizer) -> tuple[float, float]:
    """
    Train the model for one epoch.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data; each batch is ``(data, target)``.
    criterion : nn.Module
        Loss function. Assumed to return the batch *mean* (the default
        ``reduction='mean'``) so it can be re-weighted by batch size.
    optimizer : optim.Optimizer
        Optimizer for updating weights.

    Returns
    -------
    tuple[float, float]
        Per-sample average training loss and accuracy (in percent).

    Raises
    ------
    ValueError
        If ``train_loader`` yields no samples.
    """
    self.model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for data, target in train_loader:
        data, target = data.to(self.device), target.to(self.device)

        optimizer.zero_grad()
        output = self.model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        total_loss += loss.item() * batch_size
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / total, 100. * correct / total
Train the model for one epoch.
Parameters
- train_loader (DataLoader): DataLoader for training data.
- criterion (nn.Module): Loss function.
- optimizer (optim.Optimizer): Optimizer for updating weights.
Returns
- tuple[float, float]: Average training loss and accuracy (as a percentage) for the epoch.
def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
    """
    Validate the model for one epoch (no gradient tracking, eval mode).

    Parameters
    ----------
    val_loader : DataLoader
        DataLoader for validation data; each batch is ``(data, target)``.
    criterion : nn.Module
        Loss function. Assumed to return the batch *mean* (the default
        ``reduction='mean'``) so it can be re-weighted by batch size.

    Returns
    -------
    tuple[float, float]
        Per-sample average validation loss and accuracy (in percent).

    Raises
    ------
    ValueError
        If ``val_loader`` yields no samples.
    """
    self.model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = criterion(output, target)

            batch_size = target.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not skew the epoch average.
            total_loss += loss.item() * batch_size
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / total, 100. * correct / total
Validate the model for one epoch.
Parameters
- val_loader (DataLoader): DataLoader for validation data.
- criterion (nn.Module): Loss function.
Returns
- tuple[float, float]: Average validation loss and accuracy (as a percentage) for the epoch.
def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50,
          learning_rate: float = 0.001, patience: int = 10) -> dict:
    """
    Train the model for multiple epochs with early stopping.

    When the loop ends, the model is reloaded with the weights from the
    epoch with the highest validation accuracy.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data.
    val_loader : DataLoader
        DataLoader for validation data.
    epochs : int
        Maximum number of epochs to train.
    learning_rate : float
        Initial learning rate for the Adam optimizer.
    patience : int
        Stop after this many consecutive epochs without an improvement in
        validation accuracy. Defaults to 10.

    Returns
    -------
    dict
        ``self.history``: per-epoch lists under 'train_loss', 'train_acc',
        'val_loss' and 'val_acc'. The lists keep growing across calls.

    Examples
    --------
    >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
    >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
    """
    import copy

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # Halve the learning rate when validation loss stalls for 5 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Start below any achievable accuracy so the first epoch is always
    # recorded, and snapshot the initial weights so there is always a state
    # to restore (e.g. when epochs == 0).
    best_val_acc = float('-inf')
    patience_counter = 0
    self.best_model_state = copy.deepcopy(self.model.state_dict())

    print(f"Training on {self.device}")
    print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
    print("-" * 50)

    start_time = time.time()

    for epoch in range(epochs):
        train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
        val_loss, val_acc = self.validate_epoch(val_loader, criterion)
        scheduler.step(val_loss)

        self.history['train_loss'].append(train_loss)
        self.history['train_acc'].append(train_acc)
        self.history['val_loss'].append(val_loss)
        self.history['val_acc'].append(val_acc)

        if epoch % 10 == 0:
            print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps update in place; a deep copy is
            # required to actually freeze this epoch's weights.
            self.best_model_state = copy.deepcopy(self.model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    self.model.load_state_dict(self.best_model_state)
    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time:.2f} seconds")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    return self.history
Train the model for multiple epochs with early stopping.
Parameters
- train_loader (DataLoader): DataLoader for training data.
- val_loader (DataLoader): DataLoader for validation data.
- epochs (int): Maximum number of epochs to train.
- learning_rate (float): Learning rate for the optimizer.
Returns
- dict: The trainer's training history, with per-epoch lists under 'train_loss', 'train_acc', 'val_loss' and 'val_acc'.
Examples
>>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
>>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")