Loading¶
In [29]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from mpl_toolkits.basemap import Basemap
from sklearn.impute import KNNImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
from imblearn.combine import SMOTETomek
import xgboost as xgb
from sklearn.metrics import roc_curve,roc_auc_score
#load_df = pd.read_csv("Killed_and_Seriously_Injured.csv")
load_df = pd.read_csv("allfilter_injury_data2.csv")
load_df.head()
Out[29]:
 | X | Y | OBJECTID | INDEX_ | ACCNUM | DATE | TIME | STREET1 | STREET2 | OFFSET | ... | SPEEDING | AG_DRIV | REDLIGHT | ALCOHOL | DISABILITY | HOOD_158 | NEIGHBOURHOOD_158 | HOOD_140 | NEIGHBOURHOOD_140 | DIVISION
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 642702.4974 | 4.855938e+06 | 20 | 3363207 | 882024.0 | 2006/01/07 10:00:00+00 | 2325 | STEELES AVE E | NINTH LINE ST | NaN | ... | NaN | NaN | NaN | NaN | NaN | 144 | Morningside Heights | 131 | Rouge (131) | D42 |
1 | 616144.1868 | 4.841944e+06 | 32 | 3363869 | 882497.0 | 2006/01/08 10:00:00+00 | 1828 | ISLINGTON AVE | GOLFDOWN DR | NaN | ... | NaN | Yes | NaN | NaN | NaN | 5 | Elms-Old Rexdale | 5 | Elms-Old Rexdale (5) | D23 |
2 | 638249.2383 | 4.847699e+06 | 35 | 3363416 | 882174.0 | 2006/01/09 10:00:00+00 | 1435 | KENNEDY RD | GLAMORGAN AVE | NaN | ... | NaN | NaN | NaN | NaN | NaN | 126 | Dorset Park | 126 | Dorset Park (126) | D41 |
3 | 636288.2909 | 4.842392e+06 | 43 | 3363879 | 882501.0 | 2006/01/11 10:00:00+00 | 1120 | BARTLEY DR | JINNAH CRT | NaN | ... | Yes | Yes | NaN | NaN | NaN | 43 | Victoria Village | 43 | Victoria Village (43) | D55 |
4 | 638765.5901 | 4.848810e+06 | 63 | 3371161 | 886230.0 | 2006/01/21 10:00:00+00 | 1829 | MIDLAND AVE | GOODLAND GT | NaN | ... | Yes | Yes | NaN | NaN | NaN | 128 | Agincourt South-Malvern West | 128 | Agincourt South-Malvern West (128) | D42 |
5 rows × 54 columns
ETL: persons to incidents¶
In [30]:
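# One-time ETL (already run): collapse the person-level KSI records into incident-level rows.
# Fatal incidents keep every row where INJURY == 'Fatal'; non-fatal incidents are
# de-duplicated by accident number (ACCNUM). The result was saved to
# allfilter_injury_data2.csv, which is the file loaded above.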
# fatal_rows = (load_df['ACCLASS'] == 'Fatal') & (load_df['INJURY'] == 'Fatal')
# df_fatal = load_df.loc[fatal_rows]
# # df_fatal = df_fatal.drop_duplicates(subset=['ACCNUM'])
# no_fatal_row = (load_df['ACCLASS'] == 'Non-Fatal Injury')
# df_non_fatal = load_df.loc[no_fatal_row]
# df_non_fatal = df_non_fatal.drop_duplicates(subset=['ACCNUM'])
# df_final = pd.concat([df_fatal, df_non_fatal], ignore_index=True)
# df_final.to_csv('allfilter_injury_data2.csv', index=False)
EDA: initial data exploration¶
In [31]:
load_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5299 entries, 0 to 5298
Data columns (total 54 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   X                  5299 non-null   float64
 1   Y                  5299 non-null   float64
 2   OBJECTID           5299 non-null   int64
 3   INDEX_             5299 non-null   int64
 4   ACCNUM             4962 non-null   float64
 5   DATE               5299 non-null   object
 6   TIME               5299 non-null   int64
 7   STREET1            5299 non-null   object
 8   STREET2            4804 non-null   object
 9   OFFSET             732 non-null    object
 10  ROAD_CLASS         5155 non-null   object
 11  DISTRICT           5208 non-null   object
 12  LATITUDE           5299 non-null   float64
 13  LONGITUDE          5299 non-null   float64
 14  ACCLOC             3452 non-null   object
 15  TRAFFCTL           5269 non-null   object
 16  VISIBILITY         5287 non-null   object
 17  LIGHT              5297 non-null   object
 18  RDSFCOND           5286 non-null   object
 19  ACCLASS            5299 non-null   object
 20  IMPACTYPE          5290 non-null   object
 21  INVTYPE            5296 non-null   object
 22  INVAGE             5299 non-null   object
 23  INJURY             2487 non-null   object
 24  FATAL_NO           864 non-null    float64
 25  INITDIR            4000 non-null   object
 26  VEHTYPE            4816 non-null   object
 27  MANOEUVER          3448 non-null   object
 28  DRIVACT            3194 non-null   object
 29  DRIVCOND           3192 non-null   object
 30  PEDTYPE            672 non-null    object
 31  PEDACT             673 non-null    object
 32  PEDCOND            667 non-null    object
 33  CYCLISTYPE         109 non-null    object
 34  CYCACT             114 non-null    object
 35  CYCCOND            113 non-null    object
 36  PEDESTRIAN         2402 non-null   object
 37  CYCLIST            623 non-null    object
 38  AUTOMOBILE         4691 non-null   object
 39  MOTORCYCLE         530 non-null    object
 40  TRUCK              296 non-null    object
 41  TRSN_CITY_VEH      282 non-null    object
 42  EMERG_VEH          7 non-null      object
 43  PASSENGER          1223 non-null   object
 44  SPEEDING           660 non-null    object
 45  AG_DRIV            2547 non-null   object
 46  REDLIGHT           350 non-null    object
 47  ALCOHOL            208 non-null    object
 48  DISABILITY         145 non-null    object
 49  HOOD_158           5299 non-null   object
 50  NEIGHBOURHOOD_158  5299 non-null   object
 51  HOOD_140           5299 non-null   object
 52  NEIGHBOURHOOD_140  5299 non-null   object
 53  DIVISION           5299 non-null   object
dtypes: float64(6), int64(3), object(45)
memory usage: 2.2+ MB
In [32]:
print("\nMissing values:")
print(load_df.isnull().sum())
Missing values:
X                       0
Y                       0
OBJECTID                0
INDEX_                  0
ACCNUM                337
DATE                    0
TIME                    0
STREET1                 0
STREET2               495
OFFSET               4567
ROAD_CLASS            144
DISTRICT               91
LATITUDE                0
LONGITUDE               0
ACCLOC               1847
TRAFFCTL               30
VISIBILITY             12
LIGHT                   2
RDSFCOND               13
ACCLASS                 0
IMPACTYPE               9
INVTYPE                 3
INVAGE                  0
INJURY               2812
FATAL_NO             4435
INITDIR              1299
VEHTYPE               483
MANOEUVER            1851
DRIVACT              2105
DRIVCOND             2107
PEDTYPE              4627
PEDACT               4626
PEDCOND              4632
CYCLISTYPE           5190
CYCACT               5185
CYCCOND              5186
PEDESTRIAN           2897
CYCLIST              4676
AUTOMOBILE            608
MOTORCYCLE           4769
TRUCK                5003
TRSN_CITY_VEH        5017
EMERG_VEH            5292
PASSENGER            4076
SPEEDING             4639
AG_DRIV              2752
REDLIGHT             4949
ALCOHOL              5091
DISABILITY           5154
HOOD_158                0
NEIGHBOURHOOD_158       0
HOOD_140                0
NEIGHBOURHOOD_140       0
DIVISION                0
dtype: int64
Transform columns:¶
'TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'DRIVCOND', 'ACCLASS', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE'
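Every cell in this section follows the same recipe: fill missing values with a default label, map the raw categories onto a small ordinal code, and inspect the result with value_counts(). A minimal sketch of that shared pattern (the recode helper is illustrative, not part of the original notebook):

def recode(df, column, mapping, default):
    # Fill gaps with a default label, then collapse raw categories into ordinal codes
    df[column] = df[column].fillna(default).map(mapping)
    return df[column].value_counts()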
In [33]:
'''
1: 'Small Vehicles',
2: 'Trucks and Vans',
3: 'Public Transit',
4: 'Emergency and Unknown',
5: 'Special Equipment',
6: 'Off-Road',
7: 'Bicycles and Mopeds',
8: 'Motorcycles',
9: 'Rickshaws',
10: 'Others'
'''
load_df["VEHTYPE"] = load_df["VEHTYPE"].fillna('Other')
classification = {
'Automobile, Station Wagon': 1,
'Bicycle': 7,
'Motorcycle': 8,
'Pick Up Truck': 1,
'Passenger Van': 1,
'Taxi': 1,
'Moped': 7,
'Delivery Van': 2,
'Truck - Open': 2,
'Truck - Closed (Blazer, etc)': 2,
'Truck - Dump': 2,
'Truck-Tractor': 2,
'Truck (other)': 2,
'Truck - Tank': 2,
'Tow Truck': 2,
'Truck - Car Carrier': 2,
'Municipal Transit Bus (TTC)': 3,
'Street Car': 3,
'Bus (Other) (Go Bus, Gray Coa': 3,
'Intercity Bus': 3,
'School Bus': 3,
'Other': 10,
'Unknown': 4,
'Police Vehicle': 4,
'Fire Vehicle': 4,
'Other Emergency Vehicle': 4,
'Construction Equipment': 5,
'Rickshaw': 9,
'Ambulance': 4,
'Off Road - 2 Wheels': 6,
'Off Road - 4 Wheels': 6,
'Off Road - Other': 6
}
load_df['VEHTYPE'] = load_df['VEHTYPE'].map(classification)
load_df['VEHTYPE'].value_counts()
Out[33]:
VEHTYPE
1     2741
10    1923
8      351
7      127
2       96
3       54
4        5
6        1
9        1
Name: count, dtype: int64
In [34]:
'''
1: Normal
2: Impaired (includes inattentive, unknown, medical or physical disability, had been drinking, alcohol impairment, drug impairment)
3: Other (includes other and fatigue)
'''
load_df["DRIVCOND"] = load_df["DRIVCOND"].fillna('Other')
drivcond_classification = {
'Normal': 1,
'Inattentive': 2,
'Unknown': 2,
'Medical or Physical Disability': 2,
'Had Been Drinking': 2,
'Ability Impaired, Alcohol Over .08': 2,
'Ability Impaired, Alcohol': 2,
'Other': 3,
'Fatigue': 3,
'Ability Impaired, Drugs': 2
}
load_df['DRIVCOND'] = load_df['DRIVCOND'].map(drivcond_classification)
load_df['DRIVCOND'].value_counts()
Out[34]:
DRIVCOND
3    2193
1    1757
2    1349
Name: count, dtype: int64
In [35]:
'''
1: Infants and Young Children (0 to 9)
2: Adolescents (10 to 19)
3: Young Adults (20 to 34)
4: Middle-Aged Adults (35 to 49)
5: Older Adults (50 and above)
6: Unknown
'''
load_df['INVAGE'].value_counts()
age_classification = {
'unknown': 6, # Category 6: Unknown
'0 to 4': 1, # Category 1: Infants and Young Children
'5 to 9': 1, # Category 1: Infants and Young Children
'10 to 14': 2, # Category 2: Adolescents
'15 to 19': 2, # Category 2: Adolescents
'20 to 24': 3, # Category 3: Young Adults
'25 to 29': 3, # Category 3: Young Adults
'30 to 34': 3, # Category 3: Young Adults
'35 to 39': 4, # Category 4: Middle-Aged Adults
'40 to 44': 4, # Category 4: Middle-Aged Adults
'45 to 49': 4, # Category 4: Middle-Aged Adults
'50 to 54': 5, # Category 5: Older Adults
'55 to 59': 5, # Category 5: Older Adults
'60 to 64': 5, # Category 5: Older Adults
'65 to 69': 5, # Category 5: Older Adults
'70 to 74': 5, # Category 5: Older Adults
'75 to 79': 5, # Category 5: Older Adults
'80 to 84': 5, # Category 5: Older Adults
'85 to 89': 5, # Category 5: Older Adults
'90 to 94': 5, # Category 5: Older Adults
'Over 95': 5 # Category 5: Older Adults
}
# Apply classification to the DataFrame
load_df['INVAGE'] = load_df['INVAGE'].map(age_classification)
load_df['INVAGE'].value_counts()
Out[35]:
INVAGE
5    1799
3    1214
4    1048
6     994
2     200
1      44
Name: count, dtype: int64
In [36]:
'''
1: No Control (e.g., 'No Control')
2: Traffic Control Devices (e.g., 'Traffic Signal', 'Stop Sign', 'Pedestrian Crossover', etc.)
3: Other (e.g., 'Traffic Gate', 'School Guard', 'Police Control')
'''
load_df["TRAFFCTL"] = load_df["TRAFFCTL"].fillna('No Control')
load_df['TRAFFCTL'].value_counts()
traffic_control_classification = {
'No Control': 1,
'Traffic Signal': 2,
'Stop Sign': 2,
'Pedestrian Crossover': 2,
'Traffic Controller': 2,
'Yield Sign': 2,
'Streetcar (Stop for)': 2,
'Traffic Gate': 3,
'School Guard': 3,
'Police Control': 3
}
load_df['TRAFFCTL'] = load_df['TRAFFCTL'].map(traffic_control_classification)
load_df['TRAFFCTL'].value_counts()
Out[36]:
TRAFFCTL
2    2668
1    2627
3       4
Name: count, dtype: int64
In [37]:
'''
1: Clear (e.g., 'Clear')
2: Adverse Weather (e.g., 'Rain', 'Snow', 'Fog, Mist, Smoke, Dust', etc.)
3: Severe Weather (e.g., 'Strong wind')
'''
load_df["VISIBILITY"] = load_df["VISIBILITY"].fillna('Clear')
load_df['VISIBILITY'].value_counts()
# Define the classification
visibility_classification = {
'Clear': 1,
'Rain': 2,
'Snow': 2,
'Other': 2,
'Fog, Mist, Smoke, Dust': 2,
'Freezing Rain': 2,
'Drifting Snow': 2,
'Strong wind': 3
}
# Apply classification to the DataFrame
load_df['VISIBILITY'] = load_df['VISIBILITY'].map(visibility_classification)
load_df['VISIBILITY'].value_counts()
Out[37]:
VISIBILITY
1    4530
2     766
3       3
Name: count, dtype: int64
In [38]:
'''
1: Daylight (e.g., 'Daylight', 'Daylight, artificial')
2: Artificial Light (e.g., 'Dark, artificial', 'Dusk, artificial', 'Dawn, artificial')
3: Low Light (e.g., 'Dark', 'Dusk', 'Dawn', 'Other')
'''
load_df["LIGHT"] = load_df["LIGHT"].fillna('Other')
load_df['LIGHT'].value_counts()
light_classification = {
'Daylight': 1,
'Daylight, artificial': 1,
'Dark': 3,
'Dark, artificial': 2,
'Dusk': 3,
'Dusk, artificial': 2,
'Dawn': 3,
'Dawn, artificial': 2,
'Other': 3
}
# Apply classification to the DataFrame
load_df['LIGHT'] = load_df['LIGHT'].map(light_classification)
load_df['LIGHT'].value_counts()
Out[38]:
LIGHT
1    3071
3    1318
2     910
Name: count, dtype: int64
In [39]:
'''
Dry (1)
Wet (2): Includes Wet and Spilled Liquid conditions.
Slushy/Other (3): Includes Slush and any other unspecified conditions.
Loose Surface (4): Includes Loose Snow, Packed Snow, and Loose Sand/Gravel.
Ice (5): Purely icy conditions.
'''
load_df["RDSFCOND"] = load_df["RDSFCOND"].fillna('Other')
load_df['RDSFCOND'].value_counts()
road_condition_classification = {
'Dry': 1, # Category 1: Dry
'Wet': 2, # Category 2: Wet
'Slush': 3, # Category 3: Slushy
'Loose Snow': 4, # Category 4: Loose Snow
'Packed Snow': 4, # Category 4: Packed Snow
'Ice': 5, # Category 5: Ice
'Loose Sand or Gravel': 4, # Category 4: Loose Sand/Gravel
'Spilled liquid': 2, # Category 2: Wet (Spilled Liquid)
'Other': 3 # Category 3: Slushy/Other
}
load_df['RDSFCOND'] = load_df['RDSFCOND'].map(road_condition_classification)
load_df['RDSFCOND'].value_counts()
Out[39]:
RDSFCOND
1    4201
2     921
3      95
4      63
5      19
Name: count, dtype: int64
In [40]:
'''
Drivers (1): Includes all types of drivers (e.g., Car Driver, Motorcycle Driver, Truck Driver).
Cyclists/Skaters (2): Includes Cyclists, Cyclist Passengers, and In-Line Skaters.
Passengers (3): Includes Car, Motorcycle, and Moped Passengers.
Pedestrians (4): Includes Pedestrians and those using Wheelchairs.
Vehicle & Property Owners (5): Includes Vehicle Owners and Other Property Owners.
Other/Special Cases (6): Includes Witnesses, Trailer Owners, and Other unspecified cases.
'''
load_df["INVTYPE"] = load_df["INVTYPE"].fillna('Other')
invtype_classification = {
'Driver': 1,
'Motorcycle Driver': 1,
'Truck Driver': 1,
'Moped Driver': 1,
'Driver - Not Hit': 1,
'Cyclist': 2,
'In-Line Skater': 2,
'Passenger': 3,
'Motorcycle Passenger': 3,
'Pedestrian': 4,
'Wheelchair': 4,
'Vehicle Owner': 5,
'Other Property Owner': 5,
'Other': 6
}
# Apply classification to the DataFrame
load_df['INVTYPE'] = load_df['INVTYPE'].map(invtype_classification)
load_df = load_df.dropna(subset=['INVTYPE'])
load_df['INVTYPE'].value_counts()
Out[40]:
INVTYPE
1    3253
5     759
4     676
3     463
2     118
6      30
Name: count, dtype: int64
In [41]:
'''
1: Collisions Involving Vulnerable Road Users (e.g., 'Pedestrian Collisions', 'Cyclist Collisions')
2: Vehicle-to-Vehicle Collisions (e.g., 'Turning Movement', 'Rear End', 'Angle', 'Sideswipe', 'Approaching')
3: Other (e.g., 'SMV Other', 'Other', 'SMV Unattended Vehicle')
'''
load_df["IMPACTYPE"] = load_df["IMPACTYPE"].fillna('Other')
load_df['IMPACTYPE'].value_counts()
impact_type_classification = {
'Pedestrian Collisions': 1,
'Cyclist Collisions': 1,
'Turning Movement': 2,
'Rear End': 2,
'SMV Other': 2,
'Angle': 2,
'Approaching': 2,
'Sideswipe': 2,
'Other': 3,
'SMV Unattended Vehicle': 3
}
# Apply classification to the DataFrame
load_df['IMPACTYPE'] = load_df['IMPACTYPE'].map(impact_type_classification)
load_df['IMPACTYPE'].value_counts()
Out[41]:
IMPACTYPE
1    2970
2    2210
3     119
Name: count, dtype: int64
In [42]:
load_df["ACCLASS"] = load_df["ACCLASS"].fillna('Non-Fatal Injury')
load_df['ACCLASS'].value_counts()
load_df["ACCLASS"] = (
load_df["ACCLASS"].map(
{"Non-Fatal Injury": 0,
"Fatal": 1,
"Property Damage O": 0
}
)
)
load_df["ACCLASS"].value_counts()
Out[42]:
ACCLASS
0    4325
1     974
Name: count, dtype: int64
map¶
In [43]:
def mapToronto(data_full):
# Coordinates for Toronto, Canada
llcrnrlat = 43.581024 # Lower left corner latitude
urcrnrlat = 43.855457 # Upper right corner latitude
llcrnrlon = -79.639219 # Lower left corner longitude
urcrnrlon = -79.115218 # Upper right corner longitude
# Initialize the Basemap
m = Basemap(projection='merc', llcrnrlat=llcrnrlat, urcrnrlat=urcrnrlat,
llcrnrlon=llcrnrlon, urcrnrlon=urcrnrlon, resolution='i')
# Draw map details
m.drawcountries()
m.drawparallels(np.arange(-90, 91., 2.), labels=[1,0,0,0])
m.drawmeridians(np.arange(-180, 181., 2.), labels=[0,0,0,1])
# Extract data from dataframe
lat = data_full['LATITUDE'].values
lon = data_full['LONGITUDE'].values
a_1 = data_full['ACCLASS'].values
# Plot data
m.scatter(lon, lat, latlon=True, c=a_1, s=50, linewidth=1, edgecolors='red', cmap='hot', alpha=1)
# Add color bar
cbar = m.colorbar()
    cbar.set_label('ACCLASS (0 = non-fatal, 1 = fatal)')
# Add title
plt.title("Toronto, Canada Fatalities", fontsize=30)
plt.show()
# Set the style and size of the plot
sns.set(style="white", font_scale=1.5)
plt.figure(figsize=(20,20))
# Call the function to plot the map
mapToronto(load_df)
In [44]:
## Feature selection: drop ALCOHOL (5091 of 5299 values missing)
In [45]:
new_df = load_df[['TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND','DRIVCOND', 'ACCLASS', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE']]
print(new_df)
      TRAFFCTL  VISIBILITY  LIGHT  RDSFCOND  DRIVCOND  ACCLASS  IMPACTYPE  \
0            1           1      3         2         2        1          2
1            2           1      3         1         3        1          1
2            2           1      1         1         3        1          1
3            1           1      1         2         3        1          2
4            1           1      3         1         3        1          1
...        ...         ...    ...       ...       ...      ...        ...
5294         1           2      2         2         1        0          2
5295         2           1      1         1         1        0          1
5296         2           1      2         1         2        0          2
5297         2           2      2         2         1        0          1
5298         1           2      3         2         1        0          1

      INVTYPE  INVAGE  VEHTYPE
0           1       5        1
1           4       2       10
2           4       5       10
3           3       2       10
4           4       5       10
...       ...     ...      ...
5294        1       4        1
5295        1       5        1
5296        1       5        1
5297        1       5        1
5298        1       5        1

[5299 rows x 10 columns]
In [46]:
sns.heatmap(
new_df.corr(numeric_only=True),
vmin=-1,
vmax=1,
cmap="coolwarm"
)
Out[46]:
<Axes: >
imbalanced data¶
In [47]:
new_df["ACCLASS"].value_counts(normalize=True).plot.bar()
Out[47]:
<Axes: xlabel='ACCLASS'>
In [48]:
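# Hold out 20% of the data for testing; stratify=y keeps the fatal/non-fatal
# ratio the same in the train and test splits.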
X = new_df.drop(columns=['ACCLASS'])
y = new_df['ACCLASS']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
Apply SMOTETomek to the training data¶
SMOTETomek combines SMOTE oversampling of the minority (fatal) class with removal of Tomek links to clean ambiguous boundary samples. It is applied to the training split only, so the test set keeps the real class balance.
In [49]:
smote_tomek = SMOTETomek(random_state=42)
X_train_resampled, y_train_resampled = smote_tomek.fit_resample(X_train, y_train)
# Show SMOTETomek result
# Plot class distribution before SMOTETomek
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train)
plt.title('Class Distribution Before SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()
# Plot class distribution after SMOTETomek
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train_resampled)
plt.title('Class Distribution After SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()
Scale the features¶
In [50]:
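# Fit the scaler on the resampled training data only, then apply the same
# transform to the test set so no test-set statistics leak into training.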
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_resampled)
X_test_scaled = scaler.transform(X_test)
Function to train and evaluate models¶
In [51]:
def train_and_evaluate(model, X_train, X_test, y_train, y_test, model_name):
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
# Calculate probabilities for ROC AUC
if hasattr(model, "predict_proba"):
y_pred_proba = model.predict_proba(X_test)[:, 1]
else:
y_pred_proba = model.decision_function(X_test)
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
auc_score = roc_auc_score(y_test, y_pred_proba)
print(f"Results for {model_name}:")
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("AUC Score:", auc_score)
print("\n" + "="*60 + "\n")
plt.figure()
plt.plot(fpr, tpr, label=f'ROC curve (area = {auc_score:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'Receiver Operating Characteristic - {model_name}')
plt.legend(loc="lower right")
plt.show()
Define models¶
In [52]:
models = {
"Logistic Regression": LogisticRegression(max_iter=1000),
"Decision Tree": DecisionTreeClassifier(),
"Random Forest": RandomForestClassifier(n_estimators=100),
"Support Vector Machine": SVC(probability=True),
"Neural Network": MLPClassifier(hidden_layer_sizes=(100,), max_iter=300),
"XGBoost": xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss')
}
Train and evaluate each model¶
In [53]:
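# Scale-sensitive models (logistic regression, SVM, MLP) get the standardized
# features; the tree-based models train on the raw resampled features.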
for name, model in models.items():
if name in ["Logistic Regression", "Support Vector Machine", "Neural Network"]:
train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name)
else:
train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name)
Results for Logistic Regression:
Accuracy: 0.7688679245283019
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.76      0.84       865
           1       0.43      0.81      0.56       195

    accuracy                           0.77      1060
   macro avg       0.69      0.78      0.70      1060
weighted avg       0.85      0.77      0.79      1060

Confusion Matrix:
 [[657 208]
 [ 37 158]]
AUC Score: 0.8238002074996295

============================================================

Results for Decision Tree:
Accuracy: 0.8188679245283019
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.82      0.88       865
           1       0.50      0.82      0.62       195

    accuracy                           0.82      1060
   macro avg       0.73      0.82      0.75      1060
weighted avg       0.87      0.82      0.83      1060

Confusion Matrix:
 [[709 156]
 [ 36 159]]
AUC Score: 0.8935111901585889

============================================================

Results for Random Forest:
Accuracy: 0.8235849056603773
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.83      0.88       865
           1       0.51      0.81      0.63       195

    accuracy                           0.82      1060
   macro avg       0.73      0.82      0.76      1060
weighted avg       0.87      0.82      0.84      1060

Confusion Matrix:
 [[715 150]
 [ 37 158]]
AUC Score: 0.9177708611234624

============================================================

Results for Support Vector Machine:
Accuracy: 0.8207547169811321
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.81      0.88       865
           1       0.51      0.87      0.64       195

    accuracy                           0.82      1060
   macro avg       0.74      0.84      0.76      1060
weighted avg       0.88      0.82      0.84      1060

Confusion Matrix:
 [[700 165]
 [ 25 170]]
AUC Score: 0.9025018526752631

============================================================

Results for Neural Network:
Accuracy: 0.8330188679245283
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.53      0.83      0.65       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.85      1060

Confusion Matrix:
 [[722 143]
 [ 34 161]]
AUC Score: 0.9227538165110419

============================================================

Results for XGBoost:
Accuracy: 0.8292452830188679
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.84      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[716 149]
 [ 32 163]]
AUC Score: 0.9232666370238625

============================================================
Cross-validation evaluation for each model¶
In [54]:
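# Note: these scores use the original X/y (no resampling or scaling) with the
# default accuracy metric, so they are not directly comparable to the AUCs above.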
print("\nCross-validation scores:\n")
for name, model in models.items():
skf = StratifiedKFold(n_splits=5)
cv_scores = cross_val_score(model, X, y, cv=skf)
print(f"{name}: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
Cross-validation scores:

Logistic Regression: 0.795 (+/- 0.012)
Decision Tree: 0.870 (+/- 0.013)
Random Forest: 0.873 (+/- 0.013)
Support Vector Machine: 0.879 (+/- 0.015)
Neural Network: 0.885 (+/- 0.015)
XGBoost: 0.875 (+/- 0.012)
Final: all-in-one code cell: Randomized CV¶
In [55]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import xgboost as xgb
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, roc_auc_score, ConfusionMatrixDisplay
from scipy.stats import uniform, randint
from imblearn.combine import SMOTETomek
# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Apply SMOTETomek to the training data
smote_tomek = SMOTETomek(random_state=42)
X_train_resampled, y_train_resampled = smote_tomek.fit_resample(X_train, y_train)
# Show SMOTETomek result
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train)
plt.title('Class Distribution Before SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train_resampled)
plt.title('Class Distribution After SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()
# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_resampled)
X_test_scaled = scaler.transform(X_test)
# Define models, including XGBoost
models = {
"Logistic Regression": LogisticRegression(max_iter=1000),
"Decision Tree": DecisionTreeClassifier(),
"Random Forest": RandomForestClassifier(n_estimators=100),
"Support Vector Machine": SVC(probability=True),
"XGBoost": xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss')
}
# Define hyperparameter distributions
lr_params = {
'C': uniform(loc=0.1, scale=10),
'penalty': ['l1', 'l2']
}
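# NOTE: LogisticRegression's default lbfgs solver supports only 'l2', so the 'l1'
# draws fail during the search (see the FitFailedWarning in the output below);
# passing solver='liblinear' or 'saga' would make the l1 penalty searchable.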
dt_params = {
'max_depth': randint(2, 20),
'min_samples_split': randint(2, 10),
'min_samples_leaf': randint(1, 10)
}
rf_params = {
'n_estimators': randint(50, 200),
'max_depth': randint(2, 20),
'min_samples_split': randint(2, 10),
'min_samples_leaf': randint(1, 10)
}
svm_params = {
'C': uniform(loc=0.1, scale=10),
'gamma': uniform(loc=0.001, scale=0.1)
}
xgb_params = {
'n_estimators': randint(50, 200),
'max_depth': randint(2, 10),
'learning_rate': uniform(loc=0.01, scale=0.1)
}
# Function to train and evaluate models with Randomized Search CV
def train_and_evaluate(model, X_train, X_test, y_train, y_test, model_name, param_dist):
# Randomized Search CV
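    # Names registered in the models dict get a randomized hyperparameter search;
    # anything else (e.g. the soft voting classifier scored later) is fitted as-is.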
if model_name in models:
rand_search = RandomizedSearchCV(model, param_dist, n_iter=50, cv=5, scoring='roc_auc', random_state=42)
rand_search.fit(X_train, y_train)
model = rand_search.best_estimator_
print(f"Best {model_name} parameters: {rand_search.best_params_}")
else:
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
# Calculate probabilities for ROC AUC
if hasattr(model, "predict_proba"):
y_pred_proba = model.predict_proba(X_test)[:, 1]
else:
y_pred_proba = model.decision_function(X_test)
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
auc_score = roc_auc_score(y_test, y_pred_proba)
print(f"Results for {model_name}:")
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("AUC Score:", auc_score)
print("\n" + "="*60 + "\n")
# Plot ROC Curve
plt.figure()
plt.plot(fpr, tpr, label=f'ROC curve (area = {auc_score:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'Receiver Operating Characteristic - {model_name}')
plt.legend(loc="lower right")
plt.show()
# Plot Confusion Matrix
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=model.classes_)
disp.plot(cmap=plt.cm.Blues, values_format='d')
plt.title(f'Confusion Matrix - {model_name}')
plt.show()
# Train and evaluate each model
for name, model in models.items():
if name == "Logistic Regression":
train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name, lr_params)
elif name == "Decision Tree":
train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, dt_params)
elif name == "Random Forest":
train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, rf_params)
elif name == "Support Vector Machine":
train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name, svm_params)
elif name == "XGBoost":
train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, xgb_params)
# Cross-validation evaluation for each model
print("\nCross-validation scores:\n")
for name, model in models.items():
skf = StratifiedKFold(n_splits=5)
cv_scores = cross_val_score(model, X, y, cv=skf)
print(f"{name}: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
Best Logistic Regression parameters: {'C': 0.10778765841014329, 'penalty': 'l2'}
Results for Logistic Regression:
Accuracy: 0.7566037735849057
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.74      0.83       865
           1       0.42      0.81      0.55       195

    accuracy                           0.76      1060
   macro avg       0.68      0.78      0.69      1060
weighted avg       0.85      0.76      0.78      1060

Confusion Matrix:
 [[644 221]
 [ 37 158]]
AUC Score: 0.8245531347265451

============================================================
D:\anaconda\Lib\site-packages\sklearn\model_selection\_validation.py:378: FitFailedWarning:
90 fits failed out of a total of 250.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
90 fits failed with the following error:
Traceback (most recent call last):
  File "D:\anaconda\Lib\site-packages\sklearn\model_selection\_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "D:\anaconda\Lib\site-packages\sklearn\linear_model\_logistic.py", line 1162, in fit
    solver = _check_solver(self.solver, self.penalty, self.dual)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda\Lib\site-packages\sklearn\linear_model\_logistic.py", line 54, in _check_solver
    raise ValueError(
ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got l1 penalty.

  warnings.warn(some_fits_failed_message, FitFailedWarning)

D:\anaconda\Lib\site-packages\sklearn\model_selection\_search.py:952: UserWarning: One or more of the test scores are non-finite:
[       nan 0.792131          nan        nan 0.79234336 0.79217958
 0.79221477 0.79233121 0.79221728 0.79297543 0.792131          nan
        nan        nan 0.79211047        nan        nan 0.79221477
 0.7922022         nan        nan 0.79220178 0.79221603 0.7922043
        nan 0.792211          nan 0.7922043  0.79235006 0.7922022
 0.79221477 0.79220178 0.7922022  0.79243677 0.79221603 0.7922022
 0.7922043  0.79218           nan        nan 0.7922022  0.7922022
 0.79220095        nan        nan        nan 0.79221854 0.7922022
        nan 0.7921222 ]
  warnings.warn(
Best Decision Tree parameters: {'max_depth': 17, 'min_samples_leaf': 7, 'min_samples_split': 8}
Results for Decision Tree:
Accuracy: 0.8264150943396227
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.76      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[714 151]
 [ 33 162]]
AUC Score: 0.920329035126723

============================================================

Best Random Forest parameters: {'max_depth': 15, 'min_samples_leaf': 2, 'min_samples_split': 3, 'n_estimators': 58}
Results for Random Forest:
Accuracy: 0.8320754716981132
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.83      0.89       865
           1       0.53      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[721 144]
 [ 34 161]]
AUC Score: 0.9250037053505262

============================================================

Best Support Vector Machine parameters: {'C': 3.845401188473625, 'gamma': 0.09607143064099162}
Results for Support Vector Machine:
Accuracy: 0.8141509433962264
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.80      0.88       865
           1       0.50      0.88      0.63       195

    accuracy                           0.81      1060
   macro avg       0.73      0.84      0.75      1060
weighted avg       0.88      0.81      0.83      1060

Confusion Matrix:
 [[692 173]
 [ 24 171]]
AUC Score: 0.8974714688009485

============================================================

Best XGBoost parameters: {'learning_rate': 0.06396921323890797, 'max_depth': 9, 'n_estimators': 173}
Results for XGBoost:
Accuracy: 0.8273584905660377
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.76      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[715 150]
 [ 33 162]]
AUC Score: 0.9235808507484808

============================================================
Cross-validation scores:

Logistic Regression: 0.795 (+/- 0.012)
Decision Tree: 0.869 (+/- 0.013)
Random Forest: 0.876 (+/- 0.012)
Support Vector Machine: 0.879 (+/- 0.015)
XGBoost: 0.875 (+/- 0.012)
In [58]:
import shap
# xgb_model was not defined above: refit XGBoost with the best parameters found by the randomized search
xgb_model = xgb.XGBClassifier(eval_metric='mlogloss', learning_rate=0.06396921323890797, max_depth=9, n_estimators=173)
xgb_model.fit(X_train_resampled, y_train_resampled)
explainer = shap.Explainer(xgb_model)
shap_values = explainer.shap_values(X_test.sample(50, random_state=12345))
In [59]:
shap_values
Out[59]:
array([[-1.93381105e-02, -6.62288442e-02, -2.49365702e-01, -1.38345852e-01,
         1.97153926e-01,  1.91319436e-01, -7.21931934e-01, -3.89662886e+00,
        -2.00946021e+00],
       ...
       [ 7.73408934e-02,  9.99717191e-02, -1.29771918e-01, -1.55586660e-01,
         2.48188265e-02, -6.02049351e-01,  2.21430254e+00,  9.27423835e-01,
         7.14338899e-01]], dtype=float32)
(50 rows × 9 feature columns of per-row SHAP values; middle rows elided)
In [65]:
# Create the explainer
explainer = shap.TreeExplainer(xgb_model)
# Generate SHAP values (this returns an Explanation object)
shap_values = explainer(X_test)
# Now you can use the beeswarm plot
shap.plots.beeswarm(shap_values)
In [66]:
X_test.sample(50, random_state=12345).iloc[1]
Out[66]:
TRAFFCTL      2
VISIBILITY    2
LIGHT         3
RDSFCOND      2
DRIVCOND      1
IMPACTYPE     1
INVTYPE       1
INVAGE        3
VEHTYPE       1
Name: 1758, dtype: int64
In [88]:
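# shap_values here comes from the full X_test (previous cell), so shap_values[1]
# explains the second test row, not the sampled row inspected above.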
shap.plots.waterfall(shap_values[1])
Soft Voting Classifier¶
In [69]:
from sklearn.ensemble import VotingClassifier
# Create and train Soft Voting Classifier with Decision Tree included
soft_voting_model = VotingClassifier(estimators=[
('Random Forest', RandomForestClassifier(
n_estimators=63,
max_depth=13,
min_samples_leaf=2,
min_samples_split=3
)),
('XGBoost', xgb.XGBClassifier(
use_label_encoder=False,
eval_metric='mlogloss',
learning_rate=0.035877998160001694,
max_depth=9,
n_estimators=181
)),
('Decision Tree', DecisionTreeClassifier(
max_depth=17,
min_samples_leaf=7,
min_samples_split=8
))
], voting='soft')
# Train the model
soft_voting_model.fit(X_train_resampled, y_train_resampled)
# Evaluate Soft Voting Classifier
print("Soft Voting Classifier Results:")
train_and_evaluate(soft_voting_model, X_train_resampled, X_test, y_train_resampled, y_test, "Soft Voting Classifier", {})
Soft Voting Classifier Results:
Results for Soft Voting Classifier:
Accuracy: 0.8301886792452831
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.84      0.65       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[716 149]
 [ 31 164]]
AUC Score: 0.9259819178894324

============================================================
Only a numeric transformer is needed, no categorical transformer¶
All model features are already ordinal-encoded integers, as the dtype check below confirms.
In [70]:
X.dtypes
Out[70]:
TRAFFCTL      int64
VISIBILITY    int64
LIGHT         int64
RDSFCOND      int64
DRIVCOND      int64
IMPACTYPE     int64
INVTYPE       int64
INVAGE        int64
VEHTYPE       int64
dtype: object
In [89]:
# from sklearn.impute import SimpleImputer
# numeric_transformer = Pipeline([
# ("imputer", SimpleImputer(strategy="mean")),
# ("scaler", StandardScaler())
# ])
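# The voting ensemble below is tree-based (random forest, XGBoost, decision tree),
# so standardization adds nothing; mean imputation is the only preprocessing kept.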
from sklearn.impute import SimpleImputer
numeric_transformer = Pipeline([
("imputer", SimpleImputer(strategy="mean"))
])
Model deployment: pipeline and pickle¶
In [90]:
pipeline = Pipeline(steps=[
('preprocessor', numeric_transformer),
('voting', VotingClassifier(estimators=[
('Random Forest', RandomForestClassifier(
n_estimators=63,
max_depth=13,
min_samples_leaf=2,
min_samples_split=3
)),
('XGBoost', xgb.XGBClassifier(
use_label_encoder=False,
eval_metric='mlogloss',
learning_rate=0.035877998160001694,
max_depth=9,
n_estimators=181
)),
('Decision Tree', DecisionTreeClassifier(
max_depth=17,
min_samples_leaf=7,
min_samples_split=8
))
], voting='soft'))
])
In [91]:
pipeline.fit(X_train_resampled, y_train_resampled)
pipeline.score(X_test, y_test)
Out[91]:
0.8292452830188679
In [92]:
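# Persist the fitted pipeline so the scoring app below can load it without retraining.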
import joblib
joblib.dump(pipeline, "KSI_model_pipeline_voting_without_scaler.pkl")
Out[92]:
['KSI_model_pipeline_voting_without_scaler.pkl']
In [93]:
pipeline
Out[93]:
Pipeline(steps=[('preprocessor', Pipeline(steps=[('imputer', SimpleImputer())])), ('voting', VotingClassifier(estimators=[('Random Forest', RandomForestClassifier(max_depth=13, min_samples_leaf=2, min_samples_split=3, n_estimators=63)), ('XGBoost', XGBClassifier(base_score=None, booster=None, callbacks=None, colsample_bylevel=None, colsample_bynode=None, colsample... max_cat_threshold=None, max_cat_to_onehot=None, max_delta_step=None, max_depth=9, max_leaves=None, min_child_weight=None, missing=nan, monotone_constraints=None, multi_strategy=None, n_estimators=181, n_jobs=None, num_parallel_tree=None, random_state=None, ...)), ('Decision Tree', DecisionTreeClassifier(max_depth=17, min_samples_leaf=7, min_samples_split=8))], voting='soft'))])
Scoring with the saved "KSI_model_pipeline_voting_without_scaler.pkl"¶
In [94]:
KSI_model_pipeline = joblib.load("KSI_model_pipeline_voting_without_scaler.pkl")
In [95]:
KSI_model_pipeline
Out[95]:
Pipeline(steps=[('preprocessor', Pipeline(steps=[('imputer', SimpleImputer())])), ('voting', VotingClassifier(estimators=[('Random Forest', RandomForestClassifier(max_depth=13, min_samples_leaf=2, min_samples_split=3, n_estimators=63)), ('XGBoost', XGBClassifier(base_score=None, booster=None, callbacks=None, colsample_bylevel=None, colsample_bynode=None, colsample... max_cat_threshold=None, max_cat_to_onehot=None, max_delta_step=None, max_depth=9, max_leaves=None, min_child_weight=None, missing=nan, monotone_constraints=None, multi_strategy=None, n_estimators=181, n_jobs=None, num_parallel_tree=None, random_state=None, ...)), ('Decision Tree', DecisionTreeClassifier(max_depth=17, min_samples_leaf=7, min_samples_split=8))], voting='soft'))])
In [96]:
# Feature columns, in training order: 'TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'DRIVCOND', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE'
KSI_features = new_df[['TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'DRIVCOND','IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE']]
In [97]:
# KSI_model_pipeline.predict_proba(KSI_features)[:5]
In [98]:
ksi_to_score = pd.DataFrame({
"TRAFFCTL": [1, 1, 2],
"VISIBILITY": [1, 2, 1],
"LIGHT": [3, 2, 3],
"RDSFCOND": [2, 2, 1],
'DRIVCOND': [2, 2, 1],
"IMPACTYPE": [2, 2, 1],
"INVTYPE": [1, 1, 2],
"INVAGE": [2, 2, 1],
"VEHTYPE": [1, 1, 10]
})
ksi_to_score.head()
Out[98]:
 | TRAFFCTL | VISIBILITY | LIGHT | RDSFCOND | DRIVCOND | IMPACTYPE | INVTYPE | INVAGE | VEHTYPE
---|---|---|---|---|---|---|---|---|---
0 | 1 | 1 | 3 | 2 | 2 | 2 | 1 | 2 | 1 |
1 | 1 | 2 | 2 | 2 | 2 | 2 | 1 | 2 | 1 |
2 | 2 | 1 | 3 | 1 | 1 | 1 | 2 | 1 | 10 |
In [99]:
pd.DataFrame({
"predicted_prob": KSI_model_pipeline.predict_proba(ksi_to_score)[:, 1]
})
Out[99]:
 | predicted_prob
---|---
0 | 0.705179 |
1 | 0.931160 |
2 | 0.100457 |
ML Model Scoring App¶
In [7]:
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import joblib
### Setup dash app
app = dash.Dash(__name__)
app.title = 'Machine Learning Model Deployment'
server = app.server
### load ML pipeline (or model)
model_pipeline = joblib.load("KSI_model_pipeline_voting_without_scaler.pkl")
### App Layout
app.layout = html.Div([
dbc.Row([html.H3(children='Predict fatality in incidents')]),
dbc.Row([
dbc.Col(html.Label(children='Traffic Control Type:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='TRAFFCTL',
options=[
{'label': 'No Control', 'value': 1},
{'label': 'Traffic Control Devices', 'value': 2},
{'label': 'Other', 'value': 3}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Vehicle Type:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='VEHTYPE',
options=[
{'label': 'Small Vehicles', 'value': 1},
{'label': 'Trucks and Vans', 'value': 2},
{'label': 'Public Transit', 'value': 3},
{'label': 'Emergency and Unknown', 'value': 4},
{'label': 'Special Equipment', 'value': 5},
{'label': 'Off-Road', 'value': 6},
{'label': 'Bicycles and Mopeds', 'value': 7},
{'label': 'Motorcycles', 'value': 8},
{'label': 'Rickshaws', 'value': 9},
{'label': 'Others', 'value': 10}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Driver Condition:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='DRIVCOND',
options=[
{'label': 'Normal', 'value': 1},
{'label': 'Impaired', 'value': 2},
{'label': 'Other', 'value': 3}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Involved Person Age Group:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='INVAGE',
options=[
{'label': 'Infants and Young Children (0 to 9)', 'value': 1},
{'label': 'Adolescents (10 to 19)', 'value': 2},
{'label': 'Young Adults (20 to 34)', 'value': 3},
{'label': 'Middle-Aged Adults (35 to 49)', 'value': 4},
{'label': 'Older Adults (50 and above)', 'value': 5},
{'label': 'Unknown', 'value': 6}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Visibility:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='VISIBILITY',
options=[
{'label': 'Clear', 'value': 1},
{'label': 'Adverse Weather', 'value': 2},
{'label': 'Severe Weather', 'value': 3}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Light Condition:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='LIGHT',
options=[
{'label': 'Daylight', 'value': 1},
{'label': 'Artificial Light', 'value': 2},
{'label': 'Low Light', 'value': 3}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Road Surface Condition:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='RDSFCOND',
options=[
{'label': 'Dry', 'value': 1},
{'label': 'Wet', 'value': 2},
{'label': 'Slushy/Other', 'value': 3},
{'label': 'Loose Surface', 'value': 4},
{'label': 'Ice', 'value': 5}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Involved Person Type:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='INVTYPE',
options=[
{'label': 'Drivers', 'value': 1},
{'label': 'Cyclists/Skaters', 'value': 2},
{'label': 'Passengers', 'value': 3},
{'label': 'Pedestrians', 'value': 4},
{'label': 'Vehicle & Property Owners', 'value': 5},
{'label': 'Other/Special Cases', 'value': 6}
],
value=1
))
]),
dbc.Row([
dbc.Col(html.Label(children='Impact Type:'), width={"order": "first"}),
dbc.Col(dcc.Dropdown(
id='IMPACTYPE',
options=[
{'label': 'Collisions Involving Vulnerable Road Users', 'value': 1},
{'label': 'Vehicle-to-Vehicle Collisions', 'value': 2},
{'label': 'Other', 'value': 3}
],
value=1
))
]),
dbc.Row([dbc.Button('Submit', id='submit-val', n_clicks=0, color="primary")]),
html.Br(),
dbc.Row([html.Div(id='prediction output')])
], style={'padding': '0px 0px 0px 150px', 'width': '50%'})
### Callback to produce the model output
@app.callback(
Output('prediction output', 'children'),
Input('submit-val', 'n_clicks'),
State('TRAFFCTL', 'value'),
State('VISIBILITY', 'value'),
State('LIGHT', 'value'),
State('RDSFCOND', 'value'),
State('DRIVCOND', 'value'),
State('IMPACTYPE', 'value'),
State('INVTYPE', 'value'),
State('INVAGE', 'value'),
State('VEHTYPE', 'value')
)
def update_output(n_clicks, traffctl, visibility, light, rdsfcond, drivcond, impactype, invtype, invage, vehtype):
if n_clicks > 0:
# Create a DataFrame with the input values in the correct order
x = pd.DataFrame({
"TRAFFCTL": [traffctl],
"VISIBILITY": [visibility],
"LIGHT": [light],
"RDSFCOND": [rdsfcond],
"DRIVCOND": [drivcond],
"IMPACTYPE": [impactype],
"INVTYPE": [invtype],
"INVAGE": [invage],
"VEHTYPE": [vehtype]
})
# Make the prediction using the loaded model pipeline
prediction = model_pipeline.predict_proba(x)[0]
# Determine if the incident is fatal or nonfatal
if prediction[1] >= 0.5:
incident_type = 'Fatal'
else:
incident_type = 'Nonfatal'
# Format the output to show both the probability and the incident type
#
output = [
f'The incident is predicted to be: {incident_type}',
html.Br(),
f'The probability of a fatality in this incident is {round(prediction[1], 2)}'
]
else:
output = 'Please submit the form to get a prediction.'
return output
### Run the App
if __name__ == '__main__':
app.run(debug=True, port=8001, host='127.0.0.1')
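Once the cell is running, the form is served at http://127.0.0.1:8001 (the host and port set above); choose values in the dropdowns and press Submit to get the predicted class and fatality probability from the pickled voting pipeline.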