src.train

Training module for Hepatitis C classification model.

This module contains the ModelTrainer class and training utilities, separated from model definitions for better code organization.

  1"""
  2Training module for Hepatitis C classification model.
  3
  4This module contains the ModelTrainer class and training utilities,
  5separated from model definitions for better code organization.
  6"""
  7
  8from __future__ import annotations
  9import torch
 10import torch.nn as nn
 11import torch.optim as optim
 12from torch.utils.data import DataLoader
 13import numpy as np
 14import time
 15
 16
 17class ModelTrainer:
 18    """
 19    Class to handle training and validation of neural network models.
 20    
 21    This trainer can be reused with different models and datasets,
 22    following the separation of concerns principle.
 23
 24    Parameters
 25    -----------
 26    model : nn.Module
 27        The neural network model to be trained.
 28    device : str
 29        Device to run the training on ('cpu' or 'cuda').
 30    
 31    Attributes
 32    -----------
 33    model : nn.Module
 34        The neural network model to be trained.
 35    device : str
 36        Device to run the training on ('cpu' or 'cuda').
 37    history : dict
 38        Dictionary to store training history (losses and accuracies).
 39        
 40    Examples
 41    ---------
 42    >>> from src.models import HepatitisNet
 43    >>> from src.train import ModelTrainer
 44    >>> model = HepatitisNet(input_size=12)
 45    >>> trainer = ModelTrainer(model, device='cuda')
 46    >>> history = trainer.train(train_loader, val_loader, epochs=50)
 47    """
 48
 49    def __init__(self, model: nn.Module, device: str = 'cpu'):
 50        self.model = model.to(device)
 51        self.device = device
 52        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
 53
 54    def train_epoch(self, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer) -> tuple[float, float]:
 55        """
 56        Train the model for one epoch.
 57        
 58        Parameters
 59        -----------
 60        train_loader : DataLoader
 61            DataLoader for training data.
 62        criterion : nn.Module
 63            Loss function.
 64        optimizer : optim.Optimizer
 65            Optimizer for updating weights.
 66            
 67        Returns
 68        -----------
 69        tuple[float, float]
 70            Average training loss and accuracy for the epoch.
 71        """
 72        self.model.train()
 73        total_loss = 0
 74        correct = 0
 75        total = 0
 76        
 77        for data, target in train_loader:
 78            data, target = data.to(self.device), target.to(self.device)
 79            
 80            optimizer.zero_grad()
 81            output = self.model(data)
 82            loss = criterion(output, target)
 83            loss.backward()
 84            optimizer.step()
 85            
 86            total_loss += loss.item()
 87            pred = output.argmax(dim=1)
 88            correct += pred.eq(target).sum().item()
 89            total += target.size(0)
 90        
 91        return total_loss / len(train_loader), 100. * correct / total
 92
 93    def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
 94        """
 95        Validate the model for one epoch.
 96        
 97        Parameters
 98        -----------
 99        val_loader : DataLoader
100            DataLoader for validation data.
101        criterion : nn.Module
102            Loss function.
103            
104        Returns
105        -----------
106        tuple[float, float]
107            Average validation loss and accuracy for the epoch.
108        """
109        self.model.eval()
110        total_loss = 0
111        correct = 0
112        total = 0
113        
114        with torch.no_grad():
115            for data, target in val_loader:
116                data, target = data.to(self.device), target.to(self.device)
117                output = self.model(data)
118                loss = criterion(output, target)
119                
120                total_loss += loss.item()
121                pred = output.argmax(dim=1)
122                correct += pred.eq(target).sum().item()
123                total += target.size(0)
124        
125        return total_loss / len(val_loader), 100. * correct / total
126
127    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50, learning_rate: float = 0.001) -> dict:
128        """
129        Train the model for multiple epochs with early stopping.
130        
131        Parameters
132        -----------
133        train_loader : DataLoader
134            DataLoader for training data.
135        val_loader : DataLoader
136            DataLoader for validation data.
137        epochs : int
138            Maximum number of epochs to train.
139        learning_rate : float
140            Learning rate for the optimizer.
141            
142        Returns
143        -----------
144        dict
145            Training history containing losses and accuracies.
146            
147        Examples
148        ---------
149        >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
150        >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
151        """
152        criterion = nn.CrossEntropyLoss()
153        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
154        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
155        
156        best_val_acc = 0
157        patience_counter = 0
158        patience = 10
159        
160        print(f"Training on {self.device}")
161        print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
162        print("-" * 50)
163        
164        start_time = time.time()
165        
166        for epoch in range(epochs):
167            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
168            val_loss, val_acc = self.validate_epoch(val_loader, criterion)
169            scheduler.step(val_loss)
170            
171            self.history['train_loss'].append(train_loss)
172            self.history['train_acc'].append(train_acc)
173            self.history['val_loss'].append(val_loss)
174            self.history['val_acc'].append(val_acc)
175            
176            if epoch % 10 == 0:
177                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
178                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')
179            
180            if val_acc > best_val_acc:
181                best_val_acc = val_acc
182                patience_counter = 0
183                self.best_model_state = self.model.state_dict().copy()
184            else:
185                patience_counter += 1
186                if patience_counter >= patience:
187                    print(f"Early stopping at epoch {epoch}")
188                    break
189        
190        self.model.load_state_dict(self.best_model_state)
191        training_time = time.time() - start_time
192        print(f"\nTraining completed in {training_time:.2f} seconds")
193        print(f"Best validation accuracy: {best_val_acc:.2f}%")
194        
195        return self.history
class ModelTrainer:
    """
    Class to handle training and validation of neural network models.

    The trainer only relies on the model producing one score per class for
    each sample (predictions are taken with ``argmax(dim=1)``) and on the
    loaders yielding ``(data, target)`` pairs, so it can be reused with
    different models and datasets.

    Parameters
    ----------
    model : nn.Module
        The neural network model to be trained. It is moved to ``device``.
    device : str
        Device to run the training on ('cpu' or 'cuda').

    Attributes
    ----------
    model : nn.Module
        The neural network model to be trained.
    device : str
        Device to run the training on ('cpu' or 'cuda').
    history : dict
        Per-epoch lists under the keys 'train_loss', 'train_acc',
        'val_loss' and 'val_acc'. Every call to ``train`` appends to it.
    best_model_state : dict or None
        Independent copy of the weights that reached the best validation
        accuracy during the most recent ``train`` call; ``None`` before
        the first call.

    Examples
    --------
    >>> from src.models import HepatitisNet
    >>> from src.train import ModelTrainer
    >>> model = HepatitisNet(input_size=12)
    >>> trainer = ModelTrainer(model, device='cuda')
    >>> history = trainer.train(train_loader, val_loader, epochs=50)
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
        # Declared up front so the attribute exists before train() is called.
        self.best_model_state = None

    def _snapshot_state(self) -> dict:
        """Return a copy of the model weights that later updates cannot modify."""
        # state_dict() hands out tensors that share storage with the live
        # parameters, so each tensor must be cloned; a dict-level copy is
        # not enough.
        return {
            name: value.detach().clone() if torch.is_tensor(value) else value
            for name, value in self.model.state_dict().items()
        }

    def train_epoch(self, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer) -> tuple[float, float]:
        """
        Train the model for one epoch.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data.
        criterion : nn.Module
            Loss function.
        optimizer : optim.Optimizer
            Optimizer for updating weights.

        Returns
        -------
        tuple[float, float]
            Mean of the per-batch losses and accuracy in percent for the
            epoch. Both are 0.0 when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)

            optimizer.zero_grad()
            output = self.model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1

        # Guard both divisions so an empty loader does not raise.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100. * correct / total if total else 0.0
        return avg_loss, accuracy

    def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
        """
        Validate the model for one epoch.

        Parameters
        ----------
        val_loader : DataLoader
            DataLoader for validation data.
        criterion : nn.Module
            Loss function.

        Returns
        -------
        tuple[float, float]
            Mean of the per-batch losses and accuracy in percent for the
            epoch. Both are 0.0 when the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = criterion(output, target)

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
                num_batches += 1

        # Guard both divisions so an empty loader does not raise.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100. * correct / total if total else 0.0
        return avg_loss, accuracy

    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50, learning_rate: float = 0.001) -> dict:
        """
        Train the model for multiple epochs with early stopping.

        Uses cross-entropy loss and Adam (weight decay 1e-4). The learning
        rate is halved when the validation loss has not improved for 5
        epochs, and training stops after 10 consecutive epochs without a
        new best validation accuracy. The best weights seen are restored
        before returning.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader for training data.
        val_loader : DataLoader
            DataLoader for validation data.
        epochs : int
            Maximum number of epochs to train.
        learning_rate : float
            Learning rate for the optimizer.

        Returns
        -------
        dict
            Training history containing losses and accuracies.

        Examples
        --------
        >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
        >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        best_val_acc = 0
        patience_counter = 0
        patience = 10
        # Start from the current weights so there is always a state to
        # restore, even with epochs == 0 or when no epoch beats 0% accuracy.
        self.best_model_state = self._snapshot_state()

        print(f"Training on {self.device}")
        print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
        print("-" * 50)

        start_time = time.time()

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
            val_loss, val_acc = self.validate_epoch(val_loader, criterion)
            scheduler.step(val_loss)

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            if epoch % 10 == 0:
                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                self.best_model_state = self._snapshot_state()
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(self.best_model_state)
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {best_val_acc:.2f}%")

        return self.history

Class to handle training and validation of neural network models.

This trainer can be reused with different models and datasets, following the separation of concerns principle.

Parameters
  • model (nn.Module): The neural network model to be trained.
  • device (str): Device to run the training on ('cpu' or 'cuda').
Attributes
  • model (nn.Module): The neural network model to be trained.
  • device (str): Device to run the training on ('cpu' or 'cuda').
  • history (dict): Dictionary to store training history (losses and accuracies).
Examples
>>> from src.models import HepatitisNet
>>> from src.train import ModelTrainer
>>> model = HepatitisNet(input_size=12)
>>> trainer = ModelTrainer(model, device='cuda')
>>> history = trainer.train(train_loader, val_loader, epochs=50)
ModelTrainer(model: torch.nn.modules.module.Module, device: str = 'cpu')
50    def __init__(self, model: nn.Module, device: str = 'cpu'):
51        self.model = model.to(device)
52        self.device = device
53        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
model
device
history
def train_epoch( self, train_loader: torch.utils.data.dataloader.DataLoader, criterion: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer) -> tuple[float, float]:
55    def train_epoch(self, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer) -> tuple[float, float]:
56        """
57        Train the model for one epoch.
58        
59        Parameters
60        -----------
61        train_loader : DataLoader
62            DataLoader for training data.
63        criterion : nn.Module
64            Loss function.
65        optimizer : optim.Optimizer
66            Optimizer for updating weights.
67            
68        Returns
69        -----------
70        tuple[float, float]
71            Average training loss and accuracy for the epoch.
72        """
73        self.model.train()
74        total_loss = 0
75        correct = 0
76        total = 0
77        
78        for data, target in train_loader:
79            data, target = data.to(self.device), target.to(self.device)
80            
81            optimizer.zero_grad()
82            output = self.model(data)
83            loss = criterion(output, target)
84            loss.backward()
85            optimizer.step()
86            
87            total_loss += loss.item()
88            pred = output.argmax(dim=1)
89            correct += pred.eq(target).sum().item()
90            total += target.size(0)
91        
92        return total_loss / len(train_loader), 100. * correct / total

Train the model for one epoch.

Parameters
  • train_loader (DataLoader): DataLoader for training data.
  • criterion (nn.Module): Loss function.
  • optimizer (optim.Optimizer): Optimizer for updating weights.
Returns
  • tuple[float, float]: Average training loss and accuracy for the epoch.
def validate_epoch( self, val_loader: torch.utils.data.dataloader.DataLoader, criterion: torch.nn.modules.module.Module) -> tuple[float, float]:
 94    def validate_epoch(self, val_loader: DataLoader, criterion: nn.Module) -> tuple[float, float]:
 95        """
 96        Validate the model for one epoch.
 97        
 98        Parameters
 99        -----------
100        val_loader : DataLoader
101            DataLoader for validation data.
102        criterion : nn.Module
103            Loss function.
104            
105        Returns
106        -----------
107        tuple[float, float]
108            Average validation loss and accuracy for the epoch.
109        """
110        self.model.eval()
111        total_loss = 0
112        correct = 0
113        total = 0
114        
115        with torch.no_grad():
116            for data, target in val_loader:
117                data, target = data.to(self.device), target.to(self.device)
118                output = self.model(data)
119                loss = criterion(output, target)
120                
121                total_loss += loss.item()
122                pred = output.argmax(dim=1)
123                correct += pred.eq(target).sum().item()
124                total += target.size(0)
125        
126        return total_loss / len(val_loader), 100. * correct / total

Validate the model for one epoch.

Parameters
  • val_loader (DataLoader): DataLoader for validation data.
  • criterion (nn.Module): Loss function.
Returns
  • tuple[float, float]: Average validation loss and accuracy for the epoch.
def train( self, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, epochs: int = 50, learning_rate: float = 0.001) -> dict:
128    def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 50, learning_rate: float = 0.001) -> dict:
129        """
130        Train the model for multiple epochs with early stopping.
131        
132        Parameters
133        -----------
134        train_loader : DataLoader
135            DataLoader for training data.
136        val_loader : DataLoader
137            DataLoader for validation data.
138        epochs : int
139            Maximum number of epochs to train.
140        learning_rate : float
141            Learning rate for the optimizer.
142            
143        Returns
144        -----------
145        dict
146            Training history containing losses and accuracies.
147            
148        Examples
149        ---------
150        >>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
151        >>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")
152        """
153        criterion = nn.CrossEntropyLoss()
154        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
155        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
156        
157        best_val_acc = 0
158        patience_counter = 0
159        patience = 10
160        
161        print(f"Training on {self.device}")
162        print(f"Epochs: {epochs}, Learning Rate: {learning_rate}")
163        print("-" * 50)
164        
165        start_time = time.time()
166        
167        for epoch in range(epochs):
168            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer)
169            val_loss, val_acc = self.validate_epoch(val_loader, criterion)
170            scheduler.step(val_loss)
171            
172            self.history['train_loss'].append(train_loss)
173            self.history['train_acc'].append(train_acc)
174            self.history['val_loss'].append(val_loss)
175            self.history['val_acc'].append(val_acc)
176            
177            if epoch % 10 == 0:
178                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.1f}%, '
179                      f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1f}%')
180            
181            if val_acc > best_val_acc:
182                best_val_acc = val_acc
183                patience_counter = 0
184                self.best_model_state = self.model.state_dict().copy()
185            else:
186                patience_counter += 1
187                if patience_counter >= patience:
188                    print(f"Early stopping at epoch {epoch}")
189                    break
190        
191        self.model.load_state_dict(self.best_model_state)
192        training_time = time.time() - start_time
193        print(f"\nTraining completed in {training_time:.2f} seconds")
194        print(f"Best validation accuracy: {best_val_acc:.2f}%")
195        
196        return self.history

Train the model for multiple epochs with early stopping.

Parameters
  • train_loader (DataLoader): DataLoader for training data.
  • val_loader (DataLoader): DataLoader for validation data.
  • epochs (int): Maximum number of epochs to train.
  • learning_rate (float): Learning rate for the optimizer.
Returns
  • dict: Training history containing losses and accuracies.
Examples
>>> history = trainer.train(train_loader, val_loader, epochs=100, learning_rate=0.001)
>>> print(f"Best validation accuracy: {max(history['val_acc']):.2f}%")