src.visualization

  1from __future__ import annotations
  2import matplotlib.pyplot as plt
  3import seaborn as sns
  4import numpy as np
  5from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
  6import pandas as pd
  7
# Module-level plotting defaults, applied once when this module is imported:
# select matplotlib's 'default' style sheet and make "inferno" the seaborn palette.
plt.style.use('default')
sns.set_palette("inferno")
 10
 11def plot_data_overview(df: pd.DataFrame) -> plt.Figure:
 12    """
 13    Generate an overview of the dataset with key statistics and visualizations. \
 14    Includes a bar plot of disease categories, a histogram of age distribution  \
 15    of patients, a bar plot of sex distribution, and missing value counts per column.
 16
 17    Parameters
 18    ----------
 19    df: pd.DataFrame
 20        DataFrame containing the data
 21
 22    Returns
 23    -------
 24    fig: plt.Figure
 25        Matplotlib figure object containing the overview plots.
 26    """
 27
 28    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
 29    
 30    if 'Category' in df.columns:
 31        category_counts = df['Category'].value_counts()
 32        colors = sns.color_palette("inferno", len(category_counts))
 33        
 34        # Create a single stacked bar
 35        bottom = 0
 36        bar_width = 0.6
 37        for i, (category, count) in enumerate(category_counts.items()):
 38            axes[0,0].bar(['Total'], [count], bottom=bottom, color=colors[i], 
 39                         label=f'{category}: {count}', width=bar_width)
 40            bottom += count
 41        
 42        axes[0,0].set_title('Disease Categories')
 43        axes[0,0].set_xlabel('Population')
 44        axes[0,0].set_ylabel('Count')
 45        axes[0,0].legend(loc='lower left')
 46        axes[0,0].tick_params(axis='x', rotation=0)
 47    
 48    if 'Age' in df.columns:
 49        df['Age'].hist(bins=20, ax=axes[0,1], color=sns.color_palette("inferno", 10)[0])
 50        axes[0,1].set_title('Age Distribution')
 51        axes[0,1].set_xlabel('Age (years)')
 52        axes[0,1].set_ylabel('Frequency')
 53    
 54    if 'Sex' in df.columns:
 55        sex_counts = df['Sex'].value_counts()
 56        colors = sns.color_palette("inferno", len(sex_counts))
 57        
 58        # Create a single stacked bar
 59        bottom = 0
 60        bar_width = 0.6
 61        for i, (sex, count) in enumerate(sex_counts.items()):
 62            axes[1,0].bar(['Total'], [count], bottom=bottom, color=colors[i], 
 63                         label=f'{sex}: {count}', width=bar_width)
 64            bottom += count
 65        
 66        axes[1,0].set_title('Sex Distribution')
 67        axes[1,0].set_xlabel('Population')
 68        axes[1,0].set_ylabel('Count')
 69        axes[1,0].legend()
 70        axes[1,0].tick_params(axis='x', rotation=0)
 71    
 72    missing = df.isnull().sum()
 73    missing = missing[missing > 0]
 74    if len(missing) > 0:
 75        missing.plot(kind='bar', ax=axes[1,1], color=sns.color_palette("inferno", len(missing)))
 76        axes[1,1].set_title('Missing Values by Column')
 77        axes[1,1].set_ylabel('Count')
 78        axes[1,1].tick_params(axis='x', rotation=45)
 79        axes[1,1].yaxis.set_major_locator(plt.MaxNLocator(integer=True))
 80    else:
 81        axes[1,1].text(0.5, 0.5, 'No Missing Values', ha='center', va='center', transform=axes[1,1].transAxes)
 82        axes[1,1].set_title('Missing Values')
 83    
 84    plt.tight_layout()
 85    return fig
 86
 87def plot_correlation_matrix(df: pd.DataFrame) -> plt.Figure:
 88    """
 89    Plot a clustered correlation matrix for numeric features in the DataFrame.
 90    
 91    Parameters
 92    ----------
 93    df: pd.DataFrame
 94        DataFrame containing the data
 95
 96    Returns
 97    -------
 98    plt.Figure
 99        Matplotlib figure object containing the correlation matrix plot.
100    """
101
102    from scipy.cluster.hierarchy import linkage, dendrogram
103    from scipy.spatial.distance import squareform
104    
105    numeric_cols = df.select_dtypes(include=[np.number]).columns
106    
107    if len(numeric_cols) > 1:
108        plt.figure(figsize=(12, 10))
109        correlation_matrix = df[numeric_cols].corr()
110        
111        # Perform hierarchical clustering on the correlation matrix
112        # Convert correlation to distance (1 - |correlation|)
113        distance_matrix = 1 - np.abs(correlation_matrix)
114        condensed_distances = squareform(distance_matrix, checks=False)
115        linkage_matrix = linkage(condensed_distances, method='average')
116        
117        # Get the order from clustering
118        dendro = dendrogram(linkage_matrix, labels=correlation_matrix.columns, no_plot=True)
119        cluster_order = dendro['leaves']
120        
121        # Reorder the correlation matrix
122        ordered_corr = correlation_matrix.iloc[cluster_order, cluster_order]
123        
124        # Create mask for upper triangle
125        mask = np.triu(np.ones_like(ordered_corr, dtype=bool))
126        
127        sns.heatmap(ordered_corr, mask=mask, annot=True, cmap='inferno', 
128                   center=0, square=True, fmt='.2f')
129        plt.title('Feature Correlation Matrix (Clustered)')
130        plt.tight_layout()
131        return plt.gcf()
132    else:
133        print("Not enough numeric columns for correlation matrix")
134        return None
135
def plot_feature_distributions(df: pd.DataFrame, target_col: str = 'target') -> plt.Figure | None:
    """
    Create histograms of feature distributions for each numeric feature, separated by target class.

    Only the laboratory features listed in ``numeric_cols`` below that are
    present in ``df`` are plotted, five panels per row. If ``target_col`` is
    missing from ``df`` a single unsplit histogram is drawn per feature.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data
    target_col: str
        Name of the target column to separate classes. Default is 'target'.
        Values 0 and 1 are labelled 'Healthy' and 'Hepatitis C'; any other
        value is labelled with its string form.

    Returns
    -------
    plt.Figure or None
        Matplotlib figure object containing the feature distribution histograms,
        or None (after printing a message) when no feature column is present.
    """
    numeric_cols = ['ALB', 'ALP', 'ALT', 'AST', 'BIL', 'CHE', 'CHOL', 'CREA', 'GGT', 'PROT']
    available_cols = [col for col in numeric_cols if col in df.columns]

    if not available_cols:
        print("No feature columns found")
        return None

    n_cols = 5
    n_rows = (len(available_cols) + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so a single row is
    # handled the same way as several rows.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4*n_rows), squeeze=False)
    axes = axes.ravel()

    split_by_target = target_col in df.columns
    if split_by_target:
        class_values = df[target_col].dropna().unique()
        ordered_values = sorted(class_values)
        # For plain 0/1 targets keep colour index == class value; otherwise
        # assign colours by sorted position so any label set works.
        is_binary = all(value in (0, 1) for value in ordered_values)
        palette = sns.color_palette("inferno", max(2, len(ordered_values)))
        class_colors = {
            value: palette[int(value) if is_binary else position]
            for position, value in enumerate(ordered_values)
        }
        class_labels = {0: 'Healthy', 1: 'Hepatitis C'}

    for ax, feature in zip(axes, available_cols):
        if split_by_target:
            for target_value in class_values:
                # Drop missing measurements so they are not passed to the histogram.
                subset = df.loc[df[target_col] == target_value, feature].dropna()
                ax.hist(subset, alpha=0.7, bins=20,
                        label=class_labels.get(target_value, str(target_value)),
                        color=class_colors[target_value])
            ax.legend()
        else:
            df[feature].hist(bins=20, ax=ax)
        ax.set_title(f'{feature} Distribution')
        ax.set_xlabel(feature)
        ax.set_ylabel('Frequency')

    # Hide the unused panels of the last row.
    for ax in axes[len(available_cols):]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig
186
def plot_violin_with_outliers(df: pd.DataFrame, numeric_cols: list[str] | None = None) -> plt.Figure | None:
    """
    Create violin plots with overlaid box plots to show outliers for each numeric feature.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data
    numeric_cols: list[str]
        List of column names to plot. If None, uses default hepatitis C features.
        Names that are not columns of ``df`` are ignored.

    Returns
    -------
    plt.Figure or None
        Matplotlib figure object containing the violin plots with outliers, or
        None (after printing a message) when none of the requested columns exist.
    """
    if numeric_cols is None:
        numeric_cols = ['Age','ALB','ALP','ALT','AST','BIL','CHE','CHOL','CREA','GGT','PROT']

    # Filter columns that exist in the dataframe
    available_cols = [col for col in numeric_cols if col in df.columns]

    if not available_cols:
        print("No numeric columns found")
        return None

    # One panel per feature, side by side. squeeze=False keeps the result an
    # array even for a single column, so no special case is needed.
    fig, axes = plt.subplots(1, len(available_cols), figsize=(20, 6), squeeze=False)
    axes = axes.ravel()
    colors = sns.color_palette("inferno", len(available_cols))

    for position, (ax, col, color) in enumerate(zip(axes, available_cols, colors)):
        sns.violinplot(y=df[col], ax=ax, color=color, alpha=0.7)
        # Transparent box plot on top: shows quartiles and marks outliers in red.
        sns.boxplot(y=df[col], ax=ax, width=0.3, boxprops={'facecolor':'None'},
                    showfliers=True, flierprops={'marker':'o', 'markersize':3, 'markerfacecolor':'red'})
        ax.set_title(col, fontsize=12)
        ax.set_xlabel('')
        # Only the left-most panel carries the shared y-axis label.
        ax.set_ylabel('Values' if position == 0 else '')

    # Title and layout are applied to this figure, not the implicit current one.
    fig.suptitle("Violin Plots with Outliers for Each Feature", fontsize=16, y=1.02)
    fig.tight_layout()
    return fig
234
def plot_training_history(history: dict) -> plt.Figure:
    """
    Draw the loss and accuracy curves of a training run side by side.

    Parameters
    ----------
    history: dict
        Per-epoch metrics stored under the keys 'train_loss', 'val_loss',
        'train_acc' and 'val_acc'.

    Returns
    -------
    plt.Figure
        Figure with the loss panel on the left and the accuracy panel on the right.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # First palette entry for training curves, last one for validation curves.
    shades = sns.color_palette("inferno", 10)
    train_color, val_color = shades[0], shades[-1]

    # (axes, history key suffix, metric name, y-axis label)
    panels = (
        (loss_ax, 'loss', 'Loss', 'Loss'),
        (acc_ax, 'acc', 'Accuracy', 'Accuracy (%)'),
    )
    for ax, suffix, metric, y_label in panels:
        ax.plot(history[f'train_{suffix}'], label=f'Training {metric}', linewidth=2, color=train_color)
        ax.plot(history[f'val_{suffix}'], label=f'Validation {metric}', linewidth=2, color=val_color)
        ax.set_title(f'Model {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
270
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, class_names: list[str] | None = None, use_percentages: bool = True) -> plt.Figure:
    """
    Plot confusion matrix with option for percentages or absolute values.

    Parameters
    -----------
    y_true: np.ndarray
        Ground truth labels.
    y_pred: np.ndarray
        Predicted labels
    class_names: list[str]
        Tick labels for the classes. Defaults to
        ['No Hepatitis C', 'Hepatitis C'] when None.
    use_percentages: bool
        If True, show percentages by true class (each row is normalised by its
        total; a row without any true samples is shown as 0); if False, show
        absolute counts

    Returns
    --------
    plt.Figure
        Matplotlib figure object containing the confusion matrix plot.
    """
    # None sentinel instead of a mutable list default shared between calls.
    if class_names is None:
        class_names = ['No Hepatitis C', 'Hepatitis C']

    cm = confusion_matrix(y_true, y_pred)

    # Draw on an explicit figure/axes rather than the implicit current figure.
    fig, ax = plt.subplots(figsize=(8, 6))

    if use_percentages:
        # Row-wise normalisation. A class that is predicted but never occurs in
        # y_true has a row total of 0; leave that row at 0 instead of dividing
        # by zero and annotating the cells with NaN.
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm_percent = np.divide(cm.astype('float'), row_totals,
                               out=np.zeros(cm.shape, dtype=float),
                               where=row_totals != 0) * 100
        sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='inferno',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix (Percentages)', fontsize=16)
    else:
        sns.heatmap(cm, annot=True, fmt='d', cmap='inferno',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix (Counts)', fontsize=16)

    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('Actual', fontsize=12)

    # Overall accuracy = correctly classified samples (diagonal) / all samples.
    accuracy = np.trace(cm) / np.sum(cm)
    fig.text(0.1, 0.02, f'Overall Accuracy: {accuracy:.3f}', fontsize=12)

    fig.tight_layout()
    return fig
315
def plot_roc_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure:
    """
    Plot ROC curve with AUC.

    Parameters
    -----------
    y_true: np.ndarray
        Ground truth binary labels.
    y_probs: np.ndarray
        Either a 1-D array of predicted probabilities for the positive class,
        or a 2-D array of per-class probabilities, in which case column 1 is
        used as the positive-class score.

    Returns
    --------
    plt.Figure
        Matplotlib figure object containing the ROC curve plot.

    Examples
    ---------
    >>> fig = plot_roc_curve(y_true, y_probs)
    >>> fig.show()
    """
    # Accept both documented forms of input: 2-D per-class probabilities
    # (column 1 = positive class) or the positive-class scores directly.
    positive_scores = np.asarray(y_probs)
    if positive_scores.ndim == 2:
        positive_scores = positive_scores[:, 1]

    fpr, tpr, _ = roc_curve(y_true, positive_scores)
    roc_auc = auc(fpr, tpr)

    # Draw on an explicit figure/axes rather than the implicit current figure.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.3f})')
    # Diagonal reference line labelled 'Random'.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig
353
def plot_precision_recall_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure:
    """
    Plot precision-recall curve with AUC.

    Parameters
    -----------
    y_true: np.ndarray
        Ground truth binary labels.
    y_probs: np.ndarray
        Either a 1-D array of predicted probabilities for the positive class,
        or a 2-D array of per-class probabilities, in which case column 1 is
        used as the positive-class score.

    Returns
    --------
    plt.Figure
        Matplotlib figure object containing the precision-recall curve plot.

    Examples
    ---------
    >>> fig = plot_precision_recall_curve(y_true, y_probs)
    >>> fig.show()
    """
    # Accept both documented forms of input: 2-D per-class probabilities
    # (column 1 = positive class) or the positive-class scores directly.
    positive_scores = np.asarray(y_probs)
    if positive_scores.ndim == 2:
        positive_scores = positive_scores[:, 1]

    precision, recall, _ = precision_recall_curve(y_true, positive_scores)
    # Area under the precision-recall curve, integrated over recall.
    pr_auc = auc(recall, precision)

    # Draw on an explicit figure/axes rather than the implicit current figure.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, color='blue', lw=2, label=f'PR Curve (AUC = {pr_auc:.3f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc="lower left")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig
389
def plot_prediction_confidence(y_true: np.ndarray, y_probs: np.ndarray, class_names: list[str] | None = None) -> plt.Figure:
    """
    Plot histograms of prediction confidence for each class, separated by predicted and true labels.

    The confidence of a sample is its highest class probability. The left panel
    groups samples by predicted class (argmax of ``y_probs``), the right panel
    by true class.

    Parameters
    -----------
    y_true: np.ndarray
        Ground truth labels, given as class indices into ``class_names``.
    y_probs: np.ndarray
        Predicted probabilities for each class (shape: [n_samples, n_classes]).
    class_names: list of str
        Names of the classes for labeling; one histogram is drawn per name.
        Defaults to ['Healthy', 'Hepatitis C'] when None.

    Returns
    --------
    plt.Figure
        Matplotlib figure object containing the prediction confidence histograms.

    Examples
    ---------
    >>> fig = plot_prediction_confidence(y_true, y_probs)
    >>> fig.show()
    """
    # None sentinel instead of a mutable list default shared between calls.
    if class_names is None:
        class_names = ['Healthy', 'Hepatitis C']

    # Normalise to arrays: comparing a plain list with an int yields a single
    # False, which would silently leave the panels empty.
    y_true = np.asarray(y_true).ravel()
    y_probs = np.asarray(y_probs)

    fig, (predicted_ax, true_ax) = plt.subplots(1, 2, figsize=(15, 5))
    y_pred = np.argmax(y_probs, axis=1)
    max_probs = np.max(y_probs, axis=1)

    # (axes, labels used for grouping, word used in legend and title)
    panels = (
        (predicted_ax, y_pred, 'Predicted'),
        (true_ax, y_true, 'True'),
    )
    for ax, group_labels, group_kind in panels:
        for class_idx, class_name in enumerate(class_names):
            mask = group_labels == class_idx
            # Skip classes without samples so no empty histogram is drawn.
            if np.any(mask):
                ax.hist(max_probs[mask], bins=20, alpha=0.7,
                        label=f'{group_kind}: {class_name}')

        ax.set_title(f'Prediction Confidence by {group_kind} Class')
        ax.set_xlabel('Confidence')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig
def plot_data_overview(df: pandas.core.frame.DataFrame) -> matplotlib.figure.Figure:
def _plot_stacked_counts(ax, counts, title, legend_loc=None):
    """
    Draw value counts as a single stacked bar with one segment per value.

    Parameters
    ----------
    ax: matplotlib Axes
        Axes to draw on.
    counts: pd.Series
        Counts indexed by value, e.g. the result of ``Series.value_counts()``.
    title: str
        Title of the panel.
    legend_loc: str or None
        Legend location; None leaves the choice to matplotlib's configured default.
    """
    palette = sns.color_palette("inferno", len(counts))
    running_total = 0
    for segment_color, (value, count) in zip(palette, counts.items()):
        # Each segment starts where the previous one ended.
        ax.bar(['Total'], [count], bottom=running_total, color=segment_color,
               label=f'{value}: {count}', width=0.6)
        running_total += count

    ax.set_title(title)
    ax.set_xlabel('Population')
    ax.set_ylabel('Count')
    ax.legend(loc=legend_loc)
    ax.tick_params(axis='x', rotation=0)


def plot_data_overview(df: pd.DataFrame) -> plt.Figure:
    """
    Generate a 2x2 overview figure of the dataset.

    The panels are: a stacked bar of the 'Category' counts, a histogram of
    'Age', a stacked bar of the 'Sex' counts, and a bar plot of missing-value
    counts per column. A panel whose source column is absent from ``df`` is
    left empty.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data

    Returns
    -------
    fig: plt.Figure
        Matplotlib figure object containing the overview plots.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    category_ax, age_ax = axes[0]
    sex_ax, missing_ax = axes[1]

    if 'Category' in df.columns:
        _plot_stacked_counts(category_ax, df['Category'].value_counts(),
                             'Disease Categories', legend_loc='lower left')

    if 'Age' in df.columns:
        df['Age'].hist(bins=20, ax=age_ax, color=sns.color_palette("inferno", 10)[0])
        age_ax.set_title('Age Distribution')
        age_ax.set_xlabel('Age (years)')
        age_ax.set_ylabel('Frequency')

    if 'Sex' in df.columns:
        _plot_stacked_counts(sex_ax, df['Sex'].value_counts(), 'Sex Distribution')

    # Only columns that actually contain missing values are shown.
    missing = df.isnull().sum()
    missing = missing[missing > 0]
    if len(missing) > 0:
        missing.plot(kind='bar', ax=missing_ax, color=sns.color_palette("inferno", len(missing)))
        missing_ax.set_title('Missing Values by Column')
        missing_ax.set_ylabel('Count')
        missing_ax.tick_params(axis='x', rotation=45)
        # Counts are whole numbers, so keep the y ticks on integers.
        missing_ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    else:
        missing_ax.text(0.5, 0.5, 'No Missing Values', ha='center', va='center',
                        transform=missing_ax.transAxes)
        missing_ax.set_title('Missing Values')

    # Lay out this figure explicitly rather than whichever figure is current.
    fig.tight_layout()
    return fig

Generate an overview of the dataset with key statistics and visualizations. Includes a bar plot of disease categories, a histogram of age distribution of patients, a bar plot of sex distribution, and missing value counts per column.

Parameters
  • df (pd.DataFrame): DataFrame containing the data
Returns
  • fig (plt.Figure): Matplotlib figure object containing the overview plots.
def plot_correlation_matrix(df: pandas.core.frame.DataFrame) -> matplotlib.figure.Figure:
 88def plot_correlation_matrix(df: pd.DataFrame) -> plt.Figure:
 89    """
 90    Plot a clustered correlation matrix for numeric features in the DataFrame.
 91    
 92    Parameters
 93    ----------
 94    df: pd.DataFrame
 95        DataFrame containing the data
 96
 97    Returns
 98    -------
 99    plt.Figure
100        Matplotlib figure object containing the correlation matrix plot.
101    """
102
103    from scipy.cluster.hierarchy import linkage, dendrogram
104    from scipy.spatial.distance import squareform
105    
106    numeric_cols = df.select_dtypes(include=[np.number]).columns
107    
108    if len(numeric_cols) > 1:
109        plt.figure(figsize=(12, 10))
110        correlation_matrix = df[numeric_cols].corr()
111        
112        # Perform hierarchical clustering on the correlation matrix
113        # Convert correlation to distance (1 - |correlation|)
114        distance_matrix = 1 - np.abs(correlation_matrix)
115        condensed_distances = squareform(distance_matrix, checks=False)
116        linkage_matrix = linkage(condensed_distances, method='average')
117        
118        # Get the order from clustering
119        dendro = dendrogram(linkage_matrix, labels=correlation_matrix.columns, no_plot=True)
120        cluster_order = dendro['leaves']
121        
122        # Reorder the correlation matrix
123        ordered_corr = correlation_matrix.iloc[cluster_order, cluster_order]
124        
125        # Create mask for upper triangle
126        mask = np.triu(np.ones_like(ordered_corr, dtype=bool))
127        
128        sns.heatmap(ordered_corr, mask=mask, annot=True, cmap='inferno', 
129                   center=0, square=True, fmt='.2f')
130        plt.title('Feature Correlation Matrix (Clustered)')
131        plt.tight_layout()
132        return plt.gcf()
133    else:
134        print("Not enough numeric columns for correlation matrix")
135        return None

Plot a clustered correlation matrix for numeric features in the DataFrame.

Parameters
  • df (pd.DataFrame): DataFrame containing the data
Returns
  • plt.Figure: Matplotlib figure object containing the correlation matrix plot.
def plot_feature_distributions( df: pandas.core.frame.DataFrame, target_col: str = 'target') -> matplotlib.figure.Figure:
def plot_feature_distributions(df: pd.DataFrame, target_col: str = 'target') -> plt.Figure | None:
    """
    Create histograms of feature distributions for each numeric feature, separated by target class.

    Only the laboratory features listed in ``numeric_cols`` below that are
    present in ``df`` are plotted, five panels per row. If ``target_col`` is
    missing from ``df`` a single unsplit histogram is drawn per feature.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data
    target_col: str
        Name of the target column to separate classes. Default is 'target'.
        Values 0 and 1 are labelled 'Healthy' and 'Hepatitis C'; any other
        value is labelled with its string form.

    Returns
    -------
    plt.Figure or None
        Matplotlib figure object containing the feature distribution histograms,
        or None (after printing a message) when no feature column is present.
    """
    numeric_cols = ['ALB', 'ALP', 'ALT', 'AST', 'BIL', 'CHE', 'CHOL', 'CREA', 'GGT', 'PROT']
    available_cols = [col for col in numeric_cols if col in df.columns]

    if not available_cols:
        print("No feature columns found")
        return None

    n_cols = 5
    n_rows = (len(available_cols) + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so a single row is
    # handled the same way as several rows.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4*n_rows), squeeze=False)
    axes = axes.ravel()

    split_by_target = target_col in df.columns
    if split_by_target:
        class_values = df[target_col].dropna().unique()
        ordered_values = sorted(class_values)
        # For plain 0/1 targets keep colour index == class value; otherwise
        # assign colours by sorted position so any label set works.
        is_binary = all(value in (0, 1) for value in ordered_values)
        palette = sns.color_palette("inferno", max(2, len(ordered_values)))
        class_colors = {
            value: palette[int(value) if is_binary else position]
            for position, value in enumerate(ordered_values)
        }
        class_labels = {0: 'Healthy', 1: 'Hepatitis C'}

    for ax, feature in zip(axes, available_cols):
        if split_by_target:
            for target_value in class_values:
                # Drop missing measurements so they are not passed to the histogram.
                subset = df.loc[df[target_col] == target_value, feature].dropna()
                ax.hist(subset, alpha=0.7, bins=20,
                        label=class_labels.get(target_value, str(target_value)),
                        color=class_colors[target_value])
            ax.legend()
        else:
            df[feature].hist(bins=20, ax=ax)
        ax.set_title(f'{feature} Distribution')
        ax.set_xlabel(feature)
        ax.set_ylabel('Frequency')

    # Hide the unused panels of the last row.
    for ax in axes[len(available_cols):]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig

Create histograms of feature distributions for each numeric feature, separated by target class.

Parameters
  • df (pd.DataFrame): DataFrame containing the data
  • target_col (str): Name of the target column to separate classes. Default is 'target'.
Returns
  • plt.Figure: Matplotlib figure object containing the feature distribution histograms.
def plot_violin_with_outliers(df, numeric_cols=None):
def plot_violin_with_outliers(df: pd.DataFrame, numeric_cols: list[str] | None = None) -> plt.Figure | None:
    """
    Create violin plots with overlaid box plots to show outliers for each numeric feature.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data
    numeric_cols: list[str]
        List of column names to plot. If None, uses default hepatitis C features.
        Names that are not columns of ``df`` are ignored.

    Returns
    -------
    plt.Figure or None
        Matplotlib figure object containing the violin plots with outliers, or
        None (after printing a message) when none of the requested columns exist.
    """
    if numeric_cols is None:
        numeric_cols = ['Age','ALB','ALP','ALT','AST','BIL','CHE','CHOL','CREA','GGT','PROT']

    # Filter columns that exist in the dataframe
    available_cols = [col for col in numeric_cols if col in df.columns]

    if not available_cols:
        print("No numeric columns found")
        return None

    # One panel per feature, side by side. squeeze=False keeps the result an
    # array even for a single column, so no special case is needed.
    fig, axes = plt.subplots(1, len(available_cols), figsize=(20, 6), squeeze=False)
    axes = axes.ravel()
    colors = sns.color_palette("inferno", len(available_cols))

    for position, (ax, col, color) in enumerate(zip(axes, available_cols, colors)):
        sns.violinplot(y=df[col], ax=ax, color=color, alpha=0.7)
        # Transparent box plot on top: shows quartiles and marks outliers in red.
        sns.boxplot(y=df[col], ax=ax, width=0.3, boxprops={'facecolor':'None'},
                    showfliers=True, flierprops={'marker':'o', 'markersize':3, 'markerfacecolor':'red'})
        ax.set_title(col, fontsize=12)
        ax.set_xlabel('')
        # Only the left-most panel carries the shared y-axis label.
        ax.set_ylabel('Values' if position == 0 else '')

    # Title and layout are applied to this figure, not the implicit current one.
    fig.suptitle("Violin Plots with Outliers for Each Feature", fontsize=16, y=1.02)
    fig.tight_layout()
    return fig

Create violin plots with overlaid box plots to show outliers for each numeric feature.

Parameters
  • df (pd.DataFrame): DataFrame containing the data
  • numeric_cols (list[str]): List of column names to plot. If None, uses default hepatitis C features.
Returns
  • plt.Figure: Matplotlib figure object containing the violin plots with outliers.
def plot_training_history(history: dict) -> matplotlib.figure.Figure:
def plot_training_history(history: dict) -> plt.Figure:
    """
    Draw the loss and accuracy curves of a training run side by side.

    Parameters
    ----------
    history: dict
        Per-epoch metrics stored under the keys 'train_loss', 'val_loss',
        'train_acc' and 'val_acc'.

    Returns
    -------
    plt.Figure
        Figure with the loss panel on the left and the accuracy panel on the right.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # First palette entry for training curves, last one for validation curves.
    shades = sns.color_palette("inferno", 10)
    train_color, val_color = shades[0], shades[-1]

    # (axes, history key suffix, metric name, y-axis label)
    panels = (
        (loss_ax, 'loss', 'Loss', 'Loss'),
        (acc_ax, 'acc', 'Accuracy', 'Accuracy (%)'),
    )
    for ax, suffix, metric, y_label in panels:
        ax.plot(history[f'train_{suffix}'], label=f'Training {metric}', linewidth=2, color=train_color)
        ax.plot(history[f'val_{suffix}'], label=f'Validation {metric}', linewidth=2, color=val_color)
        ax.set_title(f'Model {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

Plot training and validation loss and accuracy over epochs.

Parameters
  • history (dict): Dictionary containing training history with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
Returns
  • plt.Figure: Matplotlib figure object containing the training history plots.
def plot_confusion_matrix( y_true: numpy.ndarray, y_pred: numpy.ndarray, class_names: list[str] = ['No Hepatitis C', 'Hepatitis C'], use_percentages: bool = True) -> matplotlib.figure.Figure:
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, class_names: list[str] | None = None, use_percentages: bool = True) -> plt.Figure:
    """
    Plot confusion matrix with option for percentages or absolute values.

    Parameters
    -----------
    y_true: np.ndarray
        Ground truth labels.
    y_pred: np.ndarray
        Predicted labels
    class_names: list[str]
        Tick labels for the classes. Defaults to
        ['No Hepatitis C', 'Hepatitis C'] when None.
    use_percentages: bool
        If True, show percentages by true class (each row is normalised by its
        total; a row without any true samples is shown as 0); if False, show
        absolute counts

    Returns
    --------
    plt.Figure
        Matplotlib figure object containing the confusion matrix plot.
    """
    # None sentinel instead of a mutable list default shared between calls.
    if class_names is None:
        class_names = ['No Hepatitis C', 'Hepatitis C']

    cm = confusion_matrix(y_true, y_pred)

    # Draw on an explicit figure/axes rather than the implicit current figure.
    fig, ax = plt.subplots(figsize=(8, 6))

    if use_percentages:
        # Row-wise normalisation. A class that is predicted but never occurs in
        # y_true has a row total of 0; leave that row at 0 instead of dividing
        # by zero and annotating the cells with NaN.
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm_percent = np.divide(cm.astype('float'), row_totals,
                               out=np.zeros(cm.shape, dtype=float),
                               where=row_totals != 0) * 100
        sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='inferno',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix (Percentages)', fontsize=16)
    else:
        sns.heatmap(cm, annot=True, fmt='d', cmap='inferno',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix (Counts)', fontsize=16)

    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('Actual', fontsize=12)

    # Overall accuracy = correctly classified samples (diagonal) / all samples.
    accuracy = np.trace(cm) / np.sum(cm)
    fig.text(0.1, 0.02, f'Overall Accuracy: {accuracy:.3f}', fontsize=12)

    fig.tight_layout()
    return fig

Plot confusion matrix with option for percentages or absolute values.

Parameters
  • y_true (np.ndarray): Ground truth labels.
  • y_pred (np.ndarray): Predicted labels
  • class_names (list[str]): List of class names used as tick labels
  • use_percentages (bool): If True, show percentages by true class; if False, show absolute counts
Returns
  • plt.Figure: Matplotlib figure object containing the confusion matrix plot.
def plot_roc_curve(y_true: numpy.ndarray, y_probs: numpy.ndarray) -> matplotlib.figure.Figure:
def plot_roc_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure:
    """
    Plot ROC curve with AUC.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth binary labels.
    y_probs: np.ndarray
        Either a 1-D array of scores for the positive class, or a 2-D array of
        per-class probabilities, in which case column 1 is used as the
        positive-class score.

    Returns
    -------
    plt.Figure
        Matplotlib figure object containing the ROC curve plot.

    Raises
    ------
    ValueError
        If ``y_probs`` is neither 1-D nor 2-D.

    Examples
    --------
    >>> fig = plot_roc_curve(y_true, y_probs)
    >>> fig.show()
    """
    scores = np.asarray(y_probs)
    if scores.ndim == 2:
        # Per-class probability matrix: column 1 is the positive class.
        scores = scores[:, 1]
    elif scores.ndim != 1:
        raise ValueError(
            f"y_probs must be 1-D or 2-D, got {scores.ndim} dimension(s)"
        )

    fpr, tpr, _ = roc_curve(y_true, scores)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.3f})')
    # Diagonal reference line, labelled 'Random' in the legend.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig

Plot ROC curve with AUC.

Parameters
  • y_true (np.ndarray): Ground truth binary labels.
  • y_probs (np.ndarray): Predicted class probabilities of shape (n_samples, n_classes); column 1 is used as the positive-class score.
Returns
  • plt.Figure: Matplotlib figure object containing the ROC curve plot.
Examples
>>> fig = plot_roc_curve(y_true, y_probs)
>>> fig.show()
def plot_precision_recall_curve(y_true: numpy.ndarray, y_probs: numpy.ndarray) -> matplotlib.figure.Figure:
def plot_precision_recall_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure:
    """
    Plot precision-recall curve with AUC.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth binary labels.
    y_probs: np.ndarray
        Either a 1-D array of scores for the positive class, or a 2-D array of
        per-class probabilities, in which case column 1 is used as the
        positive-class score.

    Returns
    -------
    plt.Figure
        Matplotlib figure object containing the precision-recall curve plot.

    Raises
    ------
    ValueError
        If ``y_probs`` is neither 1-D nor 2-D.

    Examples
    --------
    >>> fig = plot_precision_recall_curve(y_true, y_probs)
    >>> fig.show()
    """
    scores = np.asarray(y_probs)
    if scores.ndim == 2:
        # Per-class probability matrix: column 1 is the positive class.
        scores = scores[:, 1]
    elif scores.ndim != 1:
        raise ValueError(
            f"y_probs must be 1-D or 2-D, got {scores.ndim} dimension(s)"
        )

    precision, recall, _ = precision_recall_curve(y_true, scores)
    # Area under the PR curve, integrated over recall.
    pr_auc = auc(recall, precision)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, color='blue', lw=2, label=f'PR Curve (AUC = {pr_auc:.3f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc="lower left")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig

Plot precision-recall curve with AUC.

Parameters
  • y_true (np.ndarray): Ground truth binary labels.
  • y_probs (np.ndarray): Predicted class probabilities of shape (n_samples, n_classes); column 1 is used as the positive-class score.
Returns
  • plt.Figure: Matplotlib figure object containing the precision-recall curve plot.
Examples
>>> fig = plot_precision_recall_curve(y_true, y_probs)
>>> fig.show()
def plot_prediction_confidence(y_true: numpy.ndarray, y_probs: numpy.ndarray, class_names: list[str] = ['Healthy', 'Hepatitis C']) -> matplotlib.figure.Figure:
def plot_prediction_confidence(y_true: np.ndarray, y_probs: np.ndarray, class_names: list[str] = ['Healthy', 'Hepatitis C']) -> plt.Figure:
    """
    Plot histograms of prediction confidence, grouped once by predicted class and once by true class.

    Confidence is the highest class probability of each sample
    (``max`` over axis 1 of ``y_probs``); the predicted class is the
    corresponding ``argmax``.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth labels, encoded as integer class indices that match
        the positions in ``class_names``.
    y_probs: np.ndarray
        Predicted probabilities for each class (shape: [n_samples, n_classes]).
    class_names: list of str
        Names of the classes for labeling; one histogram is drawn per entry
        that has at least one sample. The list is only read, never mutated.

    Returns
    -------
    plt.Figure
        Matplotlib figure object containing the prediction confidence histograms.

    Examples
    --------
    >>> fig = plot_prediction_confidence(y_true, y_probs)
    >>> fig.show()
    """
    # Coerce to arrays so that ``labels == class_idx`` is always an
    # element-wise comparison (a plain list would compare to a scalar False
    # and silently produce an empty panel).
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    y_pred = np.argmax(y_probs, axis=1)
    max_probs = np.max(y_probs, axis=1)

    # (axes, labels used for grouping, word used in title and legend)
    panels = (
        (axes[0], y_pred, 'Predicted'),
        (axes[1], y_true, 'True'),
    )
    for ax, labels, grouping in panels:
        for class_idx, class_name in enumerate(class_names):
            mask = labels == class_idx
            if np.any(mask):
                ax.hist(max_probs[mask], bins=20, alpha=0.7,
                        label=f'{grouping}: {class_name}')

        ax.set_title(f'Prediction Confidence by {grouping} Class')
        ax.set_xlabel('Confidence')
        ax.set_ylabel('Frequency')
        # Skip the legend when nothing was drawn, to avoid matplotlib's
        # "no artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig

Plot histograms of prediction confidence for each class, separated by predicted and true labels.

Parameters
  • y_true (np.ndarray): Ground truth binary labels.
  • y_probs (np.ndarray): Predicted probabilities for each class (shape: [n_samples, n_classes]).
  • class_names (list of str): Names of the classes for labeling.
Returns
  • plt.Figure: Matplotlib figure object containing the prediction confidence histograms.
Examples
>>> fig = plot_prediction_confidence(y_true, y_probs)
>>> fig.show()