src.visualization
1from __future__ import annotations 2import matplotlib.pyplot as plt 3import seaborn as sns 4import numpy as np 5from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve 6import pandas as pd 7 8plt.style.use('default') 9sns.set_palette("inferno") 10 11def plot_data_overview(df: pd.DataFrame) -> plt.Figure: 12 """ 13 Generate an overview of the dataset with key statistics and visualizations. \ 14 Includes a bar plot of disease categories, a histogram of age distribution \ 15 of patients, a bar plot of sex distribution, and missing value counts per column. 16 17 Parameters 18 ---------- 19 df: pd.DataFrame 20 DataFrame containing the data 21 22 Returns 23 ------- 24 fig: plt.Figure 25 Matplotlib figure object containing the overview plots. 26 """ 27 28 fig, axes = plt.subplots(2, 2, figsize=(15, 10)) 29 30 if 'Category' in df.columns: 31 category_counts = df['Category'].value_counts() 32 colors = sns.color_palette("inferno", len(category_counts)) 33 34 # Create a single stacked bar 35 bottom = 0 36 bar_width = 0.6 37 for i, (category, count) in enumerate(category_counts.items()): 38 axes[0,0].bar(['Total'], [count], bottom=bottom, color=colors[i], 39 label=f'{category}: {count}', width=bar_width) 40 bottom += count 41 42 axes[0,0].set_title('Disease Categories') 43 axes[0,0].set_xlabel('Population') 44 axes[0,0].set_ylabel('Count') 45 axes[0,0].legend(loc='lower left') 46 axes[0,0].tick_params(axis='x', rotation=0) 47 48 if 'Age' in df.columns: 49 df['Age'].hist(bins=20, ax=axes[0,1], color=sns.color_palette("inferno", 10)[0]) 50 axes[0,1].set_title('Age Distribution') 51 axes[0,1].set_xlabel('Age (years)') 52 axes[0,1].set_ylabel('Frequency') 53 54 if 'Sex' in df.columns: 55 sex_counts = df['Sex'].value_counts() 56 colors = sns.color_palette("inferno", len(sex_counts)) 57 58 # Create a single stacked bar 59 bottom = 0 60 bar_width = 0.6 61 for i, (sex, count) in enumerate(sex_counts.items()): 62 axes[1,0].bar(['Total'], [count], 
bottom=bottom, color=colors[i], 63 label=f'{sex}: {count}', width=bar_width) 64 bottom += count 65 66 axes[1,0].set_title('Sex Distribution') 67 axes[1,0].set_xlabel('Population') 68 axes[1,0].set_ylabel('Count') 69 axes[1,0].legend() 70 axes[1,0].tick_params(axis='x', rotation=0) 71 72 missing = df.isnull().sum() 73 missing = missing[missing > 0] 74 if len(missing) > 0: 75 missing.plot(kind='bar', ax=axes[1,1], color=sns.color_palette("inferno", len(missing))) 76 axes[1,1].set_title('Missing Values by Column') 77 axes[1,1].set_ylabel('Count') 78 axes[1,1].tick_params(axis='x', rotation=45) 79 axes[1,1].yaxis.set_major_locator(plt.MaxNLocator(integer=True)) 80 else: 81 axes[1,1].text(0.5, 0.5, 'No Missing Values', ha='center', va='center', transform=axes[1,1].transAxes) 82 axes[1,1].set_title('Missing Values') 83 84 plt.tight_layout() 85 return fig 86 87def plot_correlation_matrix(df: pd.DataFrame) -> plt.Figure: 88 """ 89 Plot a clustered correlation matrix for numeric features in the DataFrame. 90 91 Parameters 92 ---------- 93 df: pd.DataFrame 94 DataFrame containing the data 95 96 Returns 97 ------- 98 plt.Figure 99 Matplotlib figure object containing the correlation matrix plot. 
100 """ 101 102 from scipy.cluster.hierarchy import linkage, dendrogram 103 from scipy.spatial.distance import squareform 104 105 numeric_cols = df.select_dtypes(include=[np.number]).columns 106 107 if len(numeric_cols) > 1: 108 plt.figure(figsize=(12, 10)) 109 correlation_matrix = df[numeric_cols].corr() 110 111 # Perform hierarchical clustering on the correlation matrix 112 # Convert correlation to distance (1 - |correlation|) 113 distance_matrix = 1 - np.abs(correlation_matrix) 114 condensed_distances = squareform(distance_matrix, checks=False) 115 linkage_matrix = linkage(condensed_distances, method='average') 116 117 # Get the order from clustering 118 dendro = dendrogram(linkage_matrix, labels=correlation_matrix.columns, no_plot=True) 119 cluster_order = dendro['leaves'] 120 121 # Reorder the correlation matrix 122 ordered_corr = correlation_matrix.iloc[cluster_order, cluster_order] 123 124 # Create mask for upper triangle 125 mask = np.triu(np.ones_like(ordered_corr, dtype=bool)) 126 127 sns.heatmap(ordered_corr, mask=mask, annot=True, cmap='inferno', 128 center=0, square=True, fmt='.2f') 129 plt.title('Feature Correlation Matrix (Clustered)') 130 plt.tight_layout() 131 return plt.gcf() 132 else: 133 print("Not enough numeric columns for correlation matrix") 134 return None 135 136def plot_feature_distributions(df: pd.DataFrame, target_col: str = 'target') -> plt.Figure: 137 """ 138 Create histograms of feature distributions for each numeric feature, separated by target class. 139 140 Parameters 141 ---------- 142 df: pd.DataFrame 143 DataFrame containing the data 144 target_col: str 145 Name of the target column to separate classes. Default is 'target'. 146 147 Returns 148 ------- 149 plt.Figure 150 Matplotlib figure object containing the feature distribution histograms. 
151 """ 152 153 numeric_cols = ['ALB', 'ALP', 'ALT', 'AST', 'BIL', 'CHE', 'CHOL', 'CREA', 'GGT', 'PROT'] 154 available_cols = [col for col in numeric_cols if col in df.columns] 155 156 if not available_cols: 157 print("No feature columns found") 158 return None 159 160 n_cols = 5 161 n_rows = (len(available_cols) + n_cols - 1) // n_cols 162 163 fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4*n_rows)) 164 axes = axes.flatten() if n_rows > 1 else [axes] if n_rows == 1 else axes 165 166 for i, feature in enumerate(available_cols): 167 if i < len(axes): 168 if target_col in df.columns: 169 for target_value in df[target_col].unique(): 170 subset = df[df[target_col] == target_value][feature] 171 label = 'Healthy' if target_value == 0 else 'Hepatitis C' 172 axes[i].hist(subset, alpha=0.7, label=label, bins=20, color=sns.color_palette("inferno", 2)[target_value]) 173 axes[i].set_title(f'{feature} Distribution') 174 axes[i].set_xlabel(feature) 175 axes[i].set_ylabel('Frequency') 176 axes[i].legend() 177 else: 178 df[feature].hist(bins=20, ax=axes[i]) 179 axes[i].set_title(f'{feature} Distribution') 180 181 for i in range(len(available_cols), len(axes)): 182 axes[i].set_visible(False) 183 184 plt.tight_layout() 185 return fig 186 187def plot_violin_with_outliers(df, numeric_cols=None): 188 """ 189 Create violin plots with overlaid box plots to show outliers for each numeric feature. 190 191 Parameters 192 ---------- 193 df: pd.DataFrame 194 DataFrame containing the data 195 numeric_cols: list[str] 196 List of column names to plot. If None, uses default hepatitis C features. 197 198 Returns 199 ------- 200 plt.Figure 201 Matplotlib figure object containing the violin plots with outliers. 
202 """ 203 if numeric_cols is None: 204 numeric_cols = ['Age','ALB','ALP','ALT','AST','BIL','CHE','CHOL','CREA','GGT','PROT'] 205 206 # Filter columns that exist in the dataframe 207 available_cols = [col for col in numeric_cols if col in df.columns] 208 209 if not available_cols: 210 print("No numeric columns found") 211 return None 212 213 # Create subplots arranged horizontally 214 fig, axes = plt.subplots(1, len(available_cols), figsize=(20, 6)) 215 colors = sns.color_palette("inferno", len(available_cols)) 216 217 # Handle case where there's only one column 218 if len(available_cols) == 1: 219 axes = [axes] 220 221 for i, col in enumerate(available_cols): 222 # Create violin plot 223 sns.violinplot(y=df[col], ax=axes[i], color=colors[i], alpha=0.7) 224 # Add box plot to show outliers and quartiles 225 sns.boxplot(y=df[col], ax=axes[i], width=0.3, boxprops={'facecolor':'None'}, 226 showfliers=True, flierprops={'marker':'o', 'markersize':3, 'markerfacecolor':'red'}) 227 axes[i].set_title(col, fontsize=12) 228 axes[i].set_xlabel('') 229 axes[i].set_ylabel('Values' if i == 0 else '') 230 231 plt.suptitle("Violin Plots with Outliers for Each Feature", fontsize=16, y=1.02) 232 plt.tight_layout() 233 return fig 234 235def plot_training_history(history: dict) -> plt.Figure: 236 """ 237 Plot training and validation loss and accuracy over epochs. 238 239 Parameters 240 ---------- 241 history: dict 242 Dictionary containing training history with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'. 243 244 Returns 245 ------- 246 plt.Figure 247 Matplotlib figure object containing the training history plots. 
248 """ 249 250 fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 251 252 axes[0].plot(history['train_loss'], label='Training Loss', linewidth=2, color=sns.color_palette("inferno", 10)[0]) 253 axes[0].plot(history['val_loss'], label='Validation Loss', linewidth=2, color=sns.color_palette("inferno", 10)[-1]) 254 axes[0].set_title('Model Loss') 255 axes[0].set_xlabel('Epoch') 256 axes[0].set_ylabel('Loss') 257 axes[0].legend() 258 axes[0].grid(True, alpha=0.3) 259 260 axes[1].plot(history['train_acc'], label='Training Accuracy', linewidth=2, color=sns.color_palette("inferno", 10)[0]) 261 axes[1].plot(history['val_acc'], label='Validation Accuracy', linewidth=2, color=sns.color_palette("inferno", 10)[-1]) 262 axes[1].set_title('Model Accuracy') 263 axes[1].set_xlabel('Epoch') 264 axes[1].set_ylabel('Accuracy (%)') 265 axes[1].legend() 266 axes[1].grid(True, alpha=0.3) 267 268 plt.tight_layout() 269 return fig 270 271def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, class_names: list[str] = ['No Hepatitis C', 'Hepatitis C'], use_percentages: bool = True) -> plt.Figure: 272 """ 273 Plot confusion matrix with option for percentages or absolute values. 274 275 Parameters 276 ----------- 277 y_true: np.ndarray 278 Ground truth labels. 279 y_pred: np.ndarray 280 Predicted labels 281 class_names: string 282 List of class names for labels 283 use_percentages: bool 284 If True, show percentages by true class; if False, show absolute counts 285 286 Returns 287 -------- 288 plt.Figure 289 Matplotlib figure object containing the confusion matrix plot. 
290 """ 291 cm = confusion_matrix(y_true, y_pred) 292 293 plt.figure(figsize=(8, 6)) 294 295 if use_percentages: 296 # Calculate percentage confusion matrix (row-wise normalization) 297 cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100 298 sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='inferno', 299 xticklabels=class_names, yticklabels=class_names) 300 plt.title('Confusion Matrix (Percentages)', fontsize=16) 301 else: 302 sns.heatmap(cm, annot=True, fmt='d', cmap='inferno', 303 xticklabels=class_names, yticklabels=class_names) 304 plt.title('Confusion Matrix (Counts)', fontsize=16) 305 306 plt.xlabel('Predicted', fontsize=12) 307 plt.ylabel('Actual', fontsize=12) 308 309 # Add accuracy information 310 accuracy = np.trace(cm) / np.sum(cm) 311 plt.figtext(0.1, 0.02, f'Overall Accuracy: {accuracy:.3f}', fontsize=12) 312 313 plt.tight_layout() 314 return plt.gcf() 315 316def plot_roc_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure: 317 """ 318 Plot ROC curve with AUC. 319 320 Parameters 321 ----------- 322 y_true: np.ndarray 323 Ground truth binary labels. 324 y_probs: np.ndarray 325 Predicted probabilities for the positive class. 326 327 Returns 328 -------- 329 plt.Figure 330 Matplotlib figure object containing the ROC curve plot. 
331 332 Examples 333 --------- 334 >>> fig = plot_roc_curve(y_true, y_probs) 335 >>> fig.show() 336 """ 337 338 fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1]) 339 roc_auc = auc(fpr, tpr) 340 341 plt.figure(figsize=(8, 6)) 342 plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.3f})') 343 plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random') 344 plt.xlim([0.0, 1.0]) 345 plt.ylim([0.0, 1.05]) 346 plt.xlabel('False Positive Rate') 347 plt.ylabel('True Positive Rate') 348 plt.title('Receiver Operating Characteristic (ROC) Curve') 349 plt.legend(loc="lower right") 350 plt.grid(True, alpha=0.3) 351 plt.tight_layout() 352 return plt.gcf() 353 354def plot_precision_recall_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure: 355 """ 356 Plot precision-recall curve with AUC. 357 358 Parameters 359 ----------- 360 y_true: np.ndarray 361 Ground truth binary labels. 362 y_probs: np.ndarray 363 Predicted probabilities for the positive class. 364 365 Returns 366 -------- 367 plt.Figure 368 Matplotlib figure object containing the precision-recall curve plot. 369 370 Examples 371 --------- 372 >>> fig = plot_precision_recall_curve(y_true, y_probs) 373 >>> fig.show() 374 """ 375 precision, recall, _ = precision_recall_curve(y_true, y_probs[:, 1]) 376 pr_auc = auc(recall, precision) 377 378 plt.figure(figsize=(8, 6)) 379 plt.plot(recall, precision, color='blue', lw=2, label=f'PR Curve (AUC = {pr_auc:.3f})') 380 plt.xlim([0.0, 1.0]) 381 plt.ylim([0.0, 1.05]) 382 plt.xlabel('Recall') 383 plt.ylabel('Precision') 384 plt.title('Precision-Recall Curve') 385 plt.legend(loc="lower left") 386 plt.grid(True, alpha=0.3) 387 plt.tight_layout() 388 return plt.gcf() 389 390def plot_prediction_confidence(y_true: np.ndarray, y_probs: np.ndarray, class_names: list[str] = ['Healthy', 'Hepatitis C']) -> plt.Figure: 391 """ 392 Plot histograms of prediction confidence for each class, separated by predicted and true labels. 
393 394 Parameters 395 ----------- 396 y_true: np.ndarray 397 Ground truth binary labels. 398 y_probs: np.ndarray 399 Predicted probabilities for each class (shape: [n_samples, n_classes]). 400 class_names: list of str 401 Names of the classes for labeling. 402 403 Returns 404 -------- 405 plt.Figure 406 Matplotlib figure object containing the prediction confidence histograms. 407 408 Examples 409 --------- 410 >>> fig = plot_prediction_confidence(y_true, y_probs) 411 >>> fig.show() 412 """ 413 414 fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 415 y_pred = np.argmax(y_probs, axis=1) 416 max_probs = np.max(y_probs, axis=1) 417 418 for class_idx in [0, 1]: 419 mask = y_pred == class_idx 420 if np.any(mask): 421 axes[0].hist(max_probs[mask], bins=20, alpha=0.7, 422 label=f'Predicted: {class_names[class_idx]}') 423 424 axes[0].set_title('Prediction Confidence by Predicted Class') 425 axes[0].set_xlabel('Confidence') 426 axes[0].set_ylabel('Frequency') 427 axes[0].legend() 428 axes[0].grid(True, alpha=0.3) 429 430 for class_idx in [0, 1]: 431 mask = y_true == class_idx 432 if np.any(mask): 433 axes[1].hist(max_probs[mask], bins=20, alpha=0.7, 434 label=f'True: {class_names[class_idx]}') 435 436 axes[1].set_title('Prediction Confidence by True Class') 437 axes[1].set_xlabel('Confidence') 438 axes[1].set_ylabel('Frequency') 439 axes[1].legend() 440 axes[1].grid(True, alpha=0.3) 441 442 plt.tight_layout() 443 return fig
def _draw_stacked_counts(ax, counts, title, legend_kwargs):
    """Draw ``counts`` (a value_counts() Series) as one stacked bar labelled 'Total'."""
    shades = sns.color_palette("inferno", len(counts))
    running_total = 0
    for shade, (level, n) in zip(shades, counts.items()):
        # Each level sits on top of the previous ones, giving a single stacked bar.
        ax.bar(['Total'], [n], bottom=running_total, color=shade,
               label=f'{level}: {n}', width=0.6)
        running_total += n
    ax.set_title(title)
    ax.set_xlabel('Population')
    ax.set_ylabel('Count')
    ax.legend(**legend_kwargs)
    ax.tick_params(axis='x', rotation=0)


def plot_data_overview(df: pd.DataFrame) -> plt.Figure:
    """
    Build a 2x2 overview of the dataset.

    Panels (row-major): a stacked bar of ``Category`` counts, a 20-bin
    histogram of ``Age``, a stacked bar of ``Sex`` counts, and a bar chart of
    missing values per column (or a "No Missing Values" note). A panel whose
    source column is absent is left empty.

    Parameters
    ----------
    df: pd.DataFrame
        DataFrame containing the data

    Returns
    -------
    fig: plt.Figure
        Matplotlib figure object containing the overview plots.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    cat_ax, age_ax, sex_ax, missing_ax = axes.flat

    if 'Category' in df.columns:
        _draw_stacked_counts(cat_ax, df['Category'].value_counts(),
                             'Disease Categories', {'loc': 'lower left'})

    if 'Age' in df.columns:
        first_shade = sns.color_palette("inferno", 10)[0]
        df['Age'].hist(bins=20, ax=age_ax, color=first_shade)
        age_ax.set_title('Age Distribution')
        age_ax.set_xlabel('Age (years)')
        age_ax.set_ylabel('Frequency')

    if 'Sex' in df.columns:
        _draw_stacked_counts(sex_ax, df['Sex'].value_counts(), 'Sex Distribution', {})

    null_counts = df.isnull().sum()
    null_counts = null_counts[null_counts > 0]
    if null_counts.empty:
        missing_ax.text(0.5, 0.5, 'No Missing Values', ha='center', va='center',
                        transform=missing_ax.transAxes)
        missing_ax.set_title('Missing Values')
    else:
        null_counts.plot(kind='bar', ax=missing_ax,
                         color=sns.color_palette("inferno", len(null_counts)))
        missing_ax.set_title('Missing Values by Column')
        missing_ax.set_ylabel('Count')
        missing_ax.tick_params(axis='x', rotation=45)
        # Missing-value counts are whole numbers; avoid fractional tick labels.
        missing_ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))

    plt.tight_layout()
    return fig
Generate an overview of the dataset with key statistics and visualizations. Includes a bar plot of disease categories, a histogram of age distribution of patients, a bar plot of sex distribution, and missing value counts per column.
Parameters
- df (pd.DataFrame): DataFrame containing the data
Returns
- fig (plt.Figure): Matplotlib figure object containing the overview plots.
88def plot_correlation_matrix(df: pd.DataFrame) -> plt.Figure: 89 """ 90 Plot a clustered correlation matrix for numeric features in the DataFrame. 91 92 Parameters 93 ---------- 94 df: pd.DataFrame 95 DataFrame containing the data 96 97 Returns 98 ------- 99 plt.Figure 100 Matplotlib figure object containing the correlation matrix plot. 101 """ 102 103 from scipy.cluster.hierarchy import linkage, dendrogram 104 from scipy.spatial.distance import squareform 105 106 numeric_cols = df.select_dtypes(include=[np.number]).columns 107 108 if len(numeric_cols) > 1: 109 plt.figure(figsize=(12, 10)) 110 correlation_matrix = df[numeric_cols].corr() 111 112 # Perform hierarchical clustering on the correlation matrix 113 # Convert correlation to distance (1 - |correlation|) 114 distance_matrix = 1 - np.abs(correlation_matrix) 115 condensed_distances = squareform(distance_matrix, checks=False) 116 linkage_matrix = linkage(condensed_distances, method='average') 117 118 # Get the order from clustering 119 dendro = dendrogram(linkage_matrix, labels=correlation_matrix.columns, no_plot=True) 120 cluster_order = dendro['leaves'] 121 122 # Reorder the correlation matrix 123 ordered_corr = correlation_matrix.iloc[cluster_order, cluster_order] 124 125 # Create mask for upper triangle 126 mask = np.triu(np.ones_like(ordered_corr, dtype=bool)) 127 128 sns.heatmap(ordered_corr, mask=mask, annot=True, cmap='inferno', 129 center=0, square=True, fmt='.2f') 130 plt.title('Feature Correlation Matrix (Clustered)') 131 plt.tight_layout() 132 return plt.gcf() 133 else: 134 print("Not enough numeric columns for correlation matrix") 135 return None
Plot a clustered correlation matrix for numeric features in the DataFrame.
Parameters
- df (pd.DataFrame): DataFrame containing the data
Returns
- plt.Figure or None: Matplotlib figure object containing the correlation matrix plot, or None if the DataFrame has fewer than two numeric columns.
137def plot_feature_distributions(df: pd.DataFrame, target_col: str = 'target') -> plt.Figure: 138 """ 139 Create histograms of feature distributions for each numeric feature, separated by target class. 140 141 Parameters 142 ---------- 143 df: pd.DataFrame 144 DataFrame containing the data 145 target_col: str 146 Name of the target column to separate classes. Default is 'target'. 147 148 Returns 149 ------- 150 plt.Figure 151 Matplotlib figure object containing the feature distribution histograms. 152 """ 153 154 numeric_cols = ['ALB', 'ALP', 'ALT', 'AST', 'BIL', 'CHE', 'CHOL', 'CREA', 'GGT', 'PROT'] 155 available_cols = [col for col in numeric_cols if col in df.columns] 156 157 if not available_cols: 158 print("No feature columns found") 159 return None 160 161 n_cols = 5 162 n_rows = (len(available_cols) + n_cols - 1) // n_cols 163 164 fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4*n_rows)) 165 axes = axes.flatten() if n_rows > 1 else [axes] if n_rows == 1 else axes 166 167 for i, feature in enumerate(available_cols): 168 if i < len(axes): 169 if target_col in df.columns: 170 for target_value in df[target_col].unique(): 171 subset = df[df[target_col] == target_value][feature] 172 label = 'Healthy' if target_value == 0 else 'Hepatitis C' 173 axes[i].hist(subset, alpha=0.7, label=label, bins=20, color=sns.color_palette("inferno", 2)[target_value]) 174 axes[i].set_title(f'{feature} Distribution') 175 axes[i].set_xlabel(feature) 176 axes[i].set_ylabel('Frequency') 177 axes[i].legend() 178 else: 179 df[feature].hist(bins=20, ax=axes[i]) 180 axes[i].set_title(f'{feature} Distribution') 181 182 for i in range(len(available_cols), len(axes)): 183 axes[i].set_visible(False) 184 185 plt.tight_layout() 186 return fig
Create histograms of feature distributions for each numeric feature, separated by target class.
Parameters
- df (pd.DataFrame): DataFrame containing the data
- target_col (str): Name of the target column to separate classes. Default is 'target'.
Returns
- plt.Figure or None: Matplotlib figure object containing the feature distribution histograms, or None if none of the expected feature columns are present.
188def plot_violin_with_outliers(df, numeric_cols=None): 189 """ 190 Create violin plots with overlaid box plots to show outliers for each numeric feature. 191 192 Parameters 193 ---------- 194 df: pd.DataFrame 195 DataFrame containing the data 196 numeric_cols: list[str] 197 List of column names to plot. If None, uses default hepatitis C features. 198 199 Returns 200 ------- 201 plt.Figure 202 Matplotlib figure object containing the violin plots with outliers. 203 """ 204 if numeric_cols is None: 205 numeric_cols = ['Age','ALB','ALP','ALT','AST','BIL','CHE','CHOL','CREA','GGT','PROT'] 206 207 # Filter columns that exist in the dataframe 208 available_cols = [col for col in numeric_cols if col in df.columns] 209 210 if not available_cols: 211 print("No numeric columns found") 212 return None 213 214 # Create subplots arranged horizontally 215 fig, axes = plt.subplots(1, len(available_cols), figsize=(20, 6)) 216 colors = sns.color_palette("inferno", len(available_cols)) 217 218 # Handle case where there's only one column 219 if len(available_cols) == 1: 220 axes = [axes] 221 222 for i, col in enumerate(available_cols): 223 # Create violin plot 224 sns.violinplot(y=df[col], ax=axes[i], color=colors[i], alpha=0.7) 225 # Add box plot to show outliers and quartiles 226 sns.boxplot(y=df[col], ax=axes[i], width=0.3, boxprops={'facecolor':'None'}, 227 showfliers=True, flierprops={'marker':'o', 'markersize':3, 'markerfacecolor':'red'}) 228 axes[i].set_title(col, fontsize=12) 229 axes[i].set_xlabel('') 230 axes[i].set_ylabel('Values' if i == 0 else '') 231 232 plt.suptitle("Violin Plots with Outliers for Each Feature", fontsize=16, y=1.02) 233 plt.tight_layout() 234 return fig
Create violin plots with overlaid box plots to show outliers for each numeric feature.
Parameters
- df (pd.DataFrame): DataFrame containing the data
- numeric_cols (list[str]): List of column names to plot. If None, uses default hepatitis C features.
Returns
- plt.Figure or None: Matplotlib figure object containing the violin plots with outliers, or None if none of the requested columns are present.
236def plot_training_history(history: dict) -> plt.Figure: 237 """ 238 Plot training and validation loss and accuracy over epochs. 239 240 Parameters 241 ---------- 242 history: dict 243 Dictionary containing training history with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'. 244 245 Returns 246 ------- 247 plt.Figure 248 Matplotlib figure object containing the training history plots. 249 """ 250 251 fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 252 253 axes[0].plot(history['train_loss'], label='Training Loss', linewidth=2, color=sns.color_palette("inferno", 10)[0]) 254 axes[0].plot(history['val_loss'], label='Validation Loss', linewidth=2, color=sns.color_palette("inferno", 10)[-1]) 255 axes[0].set_title('Model Loss') 256 axes[0].set_xlabel('Epoch') 257 axes[0].set_ylabel('Loss') 258 axes[0].legend() 259 axes[0].grid(True, alpha=0.3) 260 261 axes[1].plot(history['train_acc'], label='Training Accuracy', linewidth=2, color=sns.color_palette("inferno", 10)[0]) 262 axes[1].plot(history['val_acc'], label='Validation Accuracy', linewidth=2, color=sns.color_palette("inferno", 10)[-1]) 263 axes[1].set_title('Model Accuracy') 264 axes[1].set_xlabel('Epoch') 265 axes[1].set_ylabel('Accuracy (%)') 266 axes[1].legend() 267 axes[1].grid(True, alpha=0.3) 268 269 plt.tight_layout() 270 return fig
Plot training and validation loss and accuracy over epochs.
Parameters
- history (dict): Dictionary containing training history with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
Returns
- plt.Figure: Matplotlib figure object containing the training history plots.
272def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, class_names: list[str] = ['No Hepatitis C', 'Hepatitis C'], use_percentages: bool = True) -> plt.Figure: 273 """ 274 Plot confusion matrix with option for percentages or absolute values. 275 276 Parameters 277 ----------- 278 y_true: np.ndarray 279 Ground truth labels. 280 y_pred: np.ndarray 281 Predicted labels 282 class_names: string 283 List of class names for labels 284 use_percentages: bool 285 If True, show percentages by true class; if False, show absolute counts 286 287 Returns 288 -------- 289 plt.Figure 290 Matplotlib figure object containing the confusion matrix plot. 291 """ 292 cm = confusion_matrix(y_true, y_pred) 293 294 plt.figure(figsize=(8, 6)) 295 296 if use_percentages: 297 # Calculate percentage confusion matrix (row-wise normalization) 298 cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100 299 sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='inferno', 300 xticklabels=class_names, yticklabels=class_names) 301 plt.title('Confusion Matrix (Percentages)', fontsize=16) 302 else: 303 sns.heatmap(cm, annot=True, fmt='d', cmap='inferno', 304 xticklabels=class_names, yticklabels=class_names) 305 plt.title('Confusion Matrix (Counts)', fontsize=16) 306 307 plt.xlabel('Predicted', fontsize=12) 308 plt.ylabel('Actual', fontsize=12) 309 310 # Add accuracy information 311 accuracy = np.trace(cm) / np.sum(cm) 312 plt.figtext(0.1, 0.02, f'Overall Accuracy: {accuracy:.3f}', fontsize=12) 313 314 plt.tight_layout() 315 return plt.gcf()
Plot confusion matrix with option for percentages or absolute values.
Parameters
- y_true (np.ndarray): Ground truth labels.
- y_pred (np.ndarray): Predicted labels
- class_names (list[str]): List of class names used as the axis tick labels
- use_percentages (bool): If True, show percentages by true class; if False, show absolute counts
Returns
- plt.Figure: Matplotlib figure object containing the confusion matrix plot.
317def plot_roc_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure: 318 """ 319 Plot ROC curve with AUC. 320 321 Parameters 322 ----------- 323 y_true: np.ndarray 324 Ground truth binary labels. 325 y_probs: np.ndarray 326 Predicted probabilities for the positive class. 327 328 Returns 329 -------- 330 plt.Figure 331 Matplotlib figure object containing the ROC curve plot. 332 333 Examples 334 --------- 335 >>> fig = plot_roc_curve(y_true, y_probs) 336 >>> fig.show() 337 """ 338 339 fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1]) 340 roc_auc = auc(fpr, tpr) 341 342 plt.figure(figsize=(8, 6)) 343 plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.3f})') 344 plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random') 345 plt.xlim([0.0, 1.0]) 346 plt.ylim([0.0, 1.05]) 347 plt.xlabel('False Positive Rate') 348 plt.ylabel('True Positive Rate') 349 plt.title('Receiver Operating Characteristic (ROC) Curve') 350 plt.legend(loc="lower right") 351 plt.grid(True, alpha=0.3) 352 plt.tight_layout() 353 return plt.gcf()
Plot ROC curve with AUC.
Parameters
- y_true (np.ndarray): Ground truth binary labels.
- y_probs (np.ndarray): Predicted class probabilities of shape (n_samples, n_classes); column 1 is used as the positive-class score.
Returns
- plt.Figure: Matplotlib figure object containing the ROC curve plot.
Examples
>>> fig = plot_roc_curve(y_true, y_probs)
>>> fig.show()
355def plot_precision_recall_curve(y_true: np.ndarray, y_probs: np.ndarray) -> plt.Figure: 356 """ 357 Plot precision-recall curve with AUC. 358 359 Parameters 360 ----------- 361 y_true: np.ndarray 362 Ground truth binary labels. 363 y_probs: np.ndarray 364 Predicted probabilities for the positive class. 365 366 Returns 367 -------- 368 plt.Figure 369 Matplotlib figure object containing the precision-recall curve plot. 370 371 Examples 372 --------- 373 >>> fig = plot_precision_recall_curve(y_true, y_probs) 374 >>> fig.show() 375 """ 376 precision, recall, _ = precision_recall_curve(y_true, y_probs[:, 1]) 377 pr_auc = auc(recall, precision) 378 379 plt.figure(figsize=(8, 6)) 380 plt.plot(recall, precision, color='blue', lw=2, label=f'PR Curve (AUC = {pr_auc:.3f})') 381 plt.xlim([0.0, 1.0]) 382 plt.ylim([0.0, 1.05]) 383 plt.xlabel('Recall') 384 plt.ylabel('Precision') 385 plt.title('Precision-Recall Curve') 386 plt.legend(loc="lower left") 387 plt.grid(True, alpha=0.3) 388 plt.tight_layout() 389 return plt.gcf()
Plot precision-recall curve with AUC.
Parameters
- y_true (np.ndarray): Ground truth binary labels.
- y_probs (np.ndarray): Predicted class probabilities of shape (n_samples, n_classes); column 1 is used as the positive-class score.
Returns
- plt.Figure: Matplotlib figure object containing the precision-recall curve plot.
Examples
>>> fig = plot_precision_recall_curve(y_true, y_probs)
>>> fig.show()
391def plot_prediction_confidence(y_true: np.ndarray, y_probs: np.ndarray, class_names: list[str] = ['Healthy', 'Hepatitis C']) -> plt.Figure: 392 """ 393 Plot histograms of prediction confidence for each class, separated by predicted and true labels. 394 395 Parameters 396 ----------- 397 y_true: np.ndarray 398 Ground truth binary labels. 399 y_probs: np.ndarray 400 Predicted probabilities for each class (shape: [n_samples, n_classes]). 401 class_names: list of str 402 Names of the classes for labeling. 403 404 Returns 405 -------- 406 plt.Figure 407 Matplotlib figure object containing the prediction confidence histograms. 408 409 Examples 410 --------- 411 >>> fig = plot_prediction_confidence(y_true, y_probs) 412 >>> fig.show() 413 """ 414 415 fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 416 y_pred = np.argmax(y_probs, axis=1) 417 max_probs = np.max(y_probs, axis=1) 418 419 for class_idx in [0, 1]: 420 mask = y_pred == class_idx 421 if np.any(mask): 422 axes[0].hist(max_probs[mask], bins=20, alpha=0.7, 423 label=f'Predicted: {class_names[class_idx]}') 424 425 axes[0].set_title('Prediction Confidence by Predicted Class') 426 axes[0].set_xlabel('Confidence') 427 axes[0].set_ylabel('Frequency') 428 axes[0].legend() 429 axes[0].grid(True, alpha=0.3) 430 431 for class_idx in [0, 1]: 432 mask = y_true == class_idx 433 if np.any(mask): 434 axes[1].hist(max_probs[mask], bins=20, alpha=0.7, 435 label=f'True: {class_names[class_idx]}') 436 437 axes[1].set_title('Prediction Confidence by True Class') 438 axes[1].set_xlabel('Confidence') 439 axes[1].set_ylabel('Frequency') 440 axes[1].legend() 441 axes[1].grid(True, alpha=0.3) 442 443 plt.tight_layout() 444 return fig
Plot histograms of prediction confidence for each class, separated by predicted and true labels.
Parameters
- y_true (np.ndarray): Ground truth binary labels.
- y_probs (np.ndarray): Predicted probabilities for each class (shape: [n_samples, n_classes]).
- class_names (list of str): Names of the classes for labeling.
Returns
- plt.Figure: Matplotlib figure object containing the prediction confidence histograms.
Examples
>>> fig = plot_prediction_confidence(y_true, y_probs)
>>> fig.show()