from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression as LR_sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
roc_auc_score, roc_curve, confusion_matrix as cm_func,
accuracy_score, f1_score, precision_score, recall_score,
brier_score_loss, matthews_corrcoef
)
from collections import OrderedDict
try:
from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from imblearn.combine import SMOTETomek
IMBLEARN_AVAILABLE = True
except ImportError:
IMBLEARN_AVAILABLE = False
# Model features are: age, sex, service dummies and ALL comorbidities.
# Keep only the comorbidity columns that actually exist in df_head_neck,
# in the frame's own column order.
ml_comorbidity_cols = [
    col
    for col in df_head_neck.columns
    if col in comorbidity_cols
]

# Display names of every balancing strategy; index 0 is the unbalanced
# baseline and the remaining entries require imbalanced-learn.
ALL_BALANCING_NAMES = [
    'None (Original)',
    'SMOTE',
    'ADASYN',
    'SMOTETomek',
    'Random Oversampling',
    'Random Undersampling',
]
def get_balancing_methods(y_train):
    """Return every class-balancing method, keyed by display name.

    Each value is one of:
      * ``None`` -- use the original, unbalanced training data;
      * an imbalanced-learn sampler instance (exposes ``fit_resample``);
      * a ``str`` explaining why the method cannot be used (imblearn is not
        installed, or the minority class is too small for neighbour-based
        synthesis).

    Args:
        y_train: Binary 0/1 training labels supporting ``.sum()`` and
            ``len()`` (the caller passes a NumPy array).

    Returns:
        OrderedDict mapping method name -> ``None`` | sampler | reason string,
        in the same order as ``ALL_BALANCING_NAMES``.
    """
    methods = OrderedDict()
    methods['None (Original)'] = None

    if not IMBLEARN_AVAILABLE:
        for name in ALL_BALANCING_NAMES[1:]:
            methods[name] = 'imblearn not installed'
        return methods

    # Size the neighbourhood from whichever class is actually the minority.
    # Counting only label-1 rows is wrong when positives outnumber negatives:
    # the samplers then synthesise the *negative* class and need k to be
    # smaller than that class's size.
    n_positive = int(y_train.sum())
    n_minority = min(n_positive, len(y_train) - n_positive)

    # A sample needs at least one other same-class neighbour; cap at 5.
    k = min(5, n_minority - 1) if n_minority > 1 else 0

    if k >= 1:
        methods['SMOTE'] = SMOTE(random_state=42, k_neighbors=k)
        methods['ADASYN'] = ADASYN(random_state=42, n_neighbors=k)
        methods['SMOTETomek'] = SMOTETomek(
            random_state=42,
            smote=SMOTE(random_state=42, k_neighbors=k),
        )
    else:
        reason = f'Insufficient minority samples (n={n_minority}, need k>=1)'
        for name in ('SMOTE', 'ADASYN', 'SMOTETomek'):
            methods[name] = reason

    # Random samplers do not rely on neighbours, so they are always offered.
    methods['Random Oversampling'] = RandomOverSampler(random_state=42)
    methods['Random Undersampling'] = RandomUnderSampler(random_state=42)
    return methods
def _failed_balancing_entry(reason):
"""Return a placeholder result dict for a failed balancing technique."""
return {
'model': None,
'y_pred': None,
'y_proba': None,
'y_test': None,
'opt_threshold': np.nan,
'cm': None,
'failed': True,
'fail_reason': reason,
'metrics': {
'AUC': np.nan, 'Bal. Accuracy': np.nan, 'Sensitivity': np.nan,
'Specificity': np.nan, 'PPV': np.nan, 'NPV': np.nan,
'F1': np.nan, 'MCC': np.nan, 'Brier': np.nan,
"Youden's J": np.nan, 'Threshold': np.nan,
'Train Size': '—', 'Minority %': '—'
}
}
def _evaluate_at_youden_threshold(model, X_test, y_test):
    """Score a fitted binary classifier on the test split at its Youden-optimal cut-off.

    Args:
        model: Fitted estimator exposing ``predict_proba``.
        X_test: Test feature matrix.
        y_test: Test labels (0/1).

    Returns:
        Dict with keys ``model``, ``y_pred``, ``y_proba``, ``y_test``,
        ``opt_threshold``, ``cm`` and ``metrics`` (in that order).

    Raises:
        ValueError: If the confusion matrix is not 2x2 (the four-way unpack
            fails); callers treat this as a failed evaluation.
    """
    y_proba = model.predict_proba(X_test)[:, 1]

    # Youden's J = sensitivity + specificity - 1 = TPR - FPR; pick the ROC
    # threshold that maximises it.
    fpr, tpr, thresholds = roc_curve(y_test, y_proba)
    opt_threshold = thresholds[np.argmax(tpr - fpr)]
    y_pred = (y_proba >= opt_threshold).astype(int)

    cm = cm_func(y_test, y_pred)
    tn, fp, fn, tp = cm.ravel()
    sens = tp / (tp + fn) if (tp + fn) > 0 else 0
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv_val = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv_val = tn / (tn + fn) if (tn + fn) > 0 else 0

    return {
        'model': model,
        'y_pred': y_pred,
        'y_proba': y_proba,
        'y_test': y_test,
        'opt_threshold': opt_threshold,
        'cm': cm,
        'metrics': {
            # AUC is undefined with a single class in y_test; 0 is reported then.
            'AUC': roc_auc_score(y_test, y_proba) if len(np.unique(y_test)) > 1 else 0,
            'Bal. Accuracy': (sens + spec) / 2,
            'Sensitivity': sens,
            'Specificity': spec,
            'PPV': ppv_val,
            'NPV': npv_val,
            'F1': f1_score(y_test, y_pred, zero_division=0),
            'MCC': matthews_corrcoef(y_test, y_pred),
            'Brier': brier_score_loss(y_test, y_proba),
            "Youden's J": sens + spec - 1,
            'Threshold': opt_threshold
        }
    }


def run_ml_mortality_multi_balancing(df_site, site_name):
    """Train mortality classifiers for one site under several class-balancing schemes.

    Steps:
      1. Build the feature matrix (age, sex, service dummies, comorbidities)
         and drop rows with missing values.
      2. Split 80/20 (stratified when possible) and standardise using
         statistics from the training split only.
      3. For every balancing method, resample the training split, fit a
         gradient-boosting model and score it on the test split.
      4. Take the balancing method with the highest test AUC and fit five
         classifier families on its resampled training data.

    NOTE(review): the decision threshold and the choice of best balancing
    method are both tuned on the test split, so the reported metrics are
    optimistic estimates.

    Args:
        df_site: DataFrame for a single site. Must contain ``SEXO``, ``edad``,
            ``service_category``, ``egresar_fallecido`` and the columns listed
            in ``ml_comorbidity_cols``.
        site_name: Label for the site; used only in warning messages.

    Returns:
        Tuple ``(balancing_results, model_results, meta)``:
          * ``(None, None, None)`` when there are fewer than 30 complete rows
            or fewer than 5 rows in either outcome class;
          * ``(balancing_results, None, None)`` when every balancing method
            failed;
          * otherwise ``balancing_results`` is an OrderedDict of per-method
            results, ``model_results`` a dict of per-classifier results
            (classifiers that raised are omitted) and ``meta`` is
            ``(feature_names, X_test, y_test, best_balancing)`` with
            ``X_test`` already standardised.
    """
    import warnings  # local import: only needed to report skipped classifiers

    df_ml = df_site.copy()
    df_ml['sex_female'] = (df_ml['SEXO'] == 'Female').astype(int)
    df_ml['edad'] = pd.to_numeric(df_ml['edad'], errors='coerce')
    svc_dummies = pd.get_dummies(df_ml['service_category'], prefix='service', drop_first=True)
    for c in svc_dummies.columns:
        df_ml[c] = svc_dummies[c]

    feature_names = ['edad', 'sex_female'] + list(svc_dummies.columns) + ml_comorbidity_cols
    target = 'egresar_fallecido'
    df_ml = df_ml[feature_names + [target]].dropna()

    n_pos = int(df_ml[target].sum())
    n_neg = int((df_ml[target] == 0).sum())
    if len(df_ml) < 30 or n_pos < 5 or n_neg < 5:
        return None, None, None

    X = df_ml[feature_names].values
    y = df_ml[target].values

    try:
        X_train_raw, X_test_raw, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
    except ValueError:
        # Stratification is impossible for this label distribution;
        # fall back to a plain random split.
        X_train_raw, X_test_raw, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

    # Fit the scaler on the training split only so that test-set statistics
    # do not leak into the features the models are trained on.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_test = scaler.transform(X_test_raw)

    balancing_methods = get_balancing_methods(y_train)
    base_model_class = GradientBoostingClassifier
    base_model_params = {'n_estimators': 150, 'random_state': 42, 'max_depth': 4}

    balancing_results = OrderedDict()
    # Resampled training data of each successful method, so the winner can be
    # reused below without running the sampler a second time.
    resampled_train = {}

    for bal_name, balancer in balancing_methods.items():
        if isinstance(balancer, str):
            # A string value is the reason the method is unavailable.
            balancing_results[bal_name] = _failed_balancing_entry(balancer)
            continue
        try:
            if balancer is None:
                X_res, y_res = X_train, y_train
            else:
                X_res, y_res = balancer.fit_resample(X_train, y_train)
            model = base_model_class(**base_model_params)
            model.fit(X_res, y_res)

            entry = _evaluate_at_youden_threshold(model, X_test, y_test)
            # Re-insert 'metrics' last so the key order stays
            # ..., 'cm', 'failed', 'metrics'.
            metrics = entry.pop('metrics')
            metrics['Train Size'] = len(y_res)
            # Despite the label, this is the share of label-1 rows after resampling.
            metrics['Minority %'] = f'{y_res.sum()/len(y_res)*100:.1f}'
            entry['failed'] = False
            entry['metrics'] = metrics
        except Exception as e:
            balancing_results[bal_name] = _failed_balancing_entry(str(e))
        else:
            balancing_results[bal_name] = entry
            resampled_train[bal_name] = (X_res, y_res)

    successful_results = {k: v for k, v in balancing_results.items() if not v.get('failed', False)}
    if not successful_results:
        return balancing_results, None, None

    best_balancing = max(successful_results, key=lambda k: successful_results[k]['metrics']['AUC'])
    X_best, y_best = resampled_train[best_balancing]

    all_models = {
        'Logistic Regression': LR_sklearn(max_iter=1000, random_state=42, class_weight='balanced'),
        'Random Forest': RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1, class_weight='balanced'),
        'Gradient Boosting': GradientBoostingClassifier(n_estimators=150, random_state=42, max_depth=4),
        'SVM': SVC(kernel='rbf', probability=True, random_state=42, class_weight='balanced'),
        'MLP': MLPClassifier(hidden_layer_sizes=(64, 32), max_iter=1000, random_state=42)
    }

    model_results = {}
    for name, model in all_models.items():
        try:
            model.fit(X_best, y_best)
            model_results[name] = _evaluate_at_youden_threshold(model, X_test, y_test)
        except Exception as exc:
            # Best effort: a classifier that fails is left out of the results,
            # but the failure is reported instead of silently swallowed.
            warnings.warn(
                f"{site_name}: {name} failed and was skipped ({exc})",
                RuntimeWarning,
            )

    return balancing_results, model_results, (feature_names, X_test, y_test, best_balancing)
# Fit the mortality models separately for EACH cancer site and keep whatever
# outputs each run produced, keyed by site.
ml_balancing_by_site = {}
ml_models_by_site = {}
ml_meta_by_site = {}

for site in all_sites:
    site_mask = df_head_neck['cancer_site_group'] == site
    df_site = df_head_neck[site_mask].copy()

    bal_results, model_results, meta = run_ml_mortality_multi_balancing(df_site, site)

    # None and an empty mapping both mean there is nothing to record.
    if bal_results:
        ml_balancing_by_site[site] = bal_results
    if model_results is not None:
        ml_models_by_site[site] = model_results
    if meta is not None:
        ml_meta_by_site[site] = meta