Class 3: ML Pipelines

Goals when using ML

  1. Understand the data (data science / actual science): apply probability and statistics; perhaps fit a model, then examine and inspect its parameters

  2. Understand the model: e.g., how a Naive Bayes fit varies across different datasets

  3. Make claims about the learning algorithm: run multiple algorithms on the same data, and possibly on multiple datasets

Basic setup

  1. train/test split

  2. training parameters

  3. estimator objects

  4. fit model parameters

  5. metrics

  6. cross validation

import pandas as pd
import seaborn as sns
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix, classification_report
from sklearn import datasets
# Load the iris dataset as a pandas DataFrame via seaborn's bundled datasets.
iris_df = sns.load_dataset('iris')
# Scatter-matrix of every feature pair, colored by species — a quick visual
# check of how separable the three classes are.
sns.pairplot(iris_df, hue='species')
<seaborn.axisgrid.PairGrid at 0x7f0764ddbc90>
../_images/2021-02-03_3_1.png
# Load the same iris data as plain numpy arrays: X = features, y = integer labels.
X,y = datasets.load_iris(return_X_y=True)
# 150 samples, 4 features each.
X.shape
(150, 4)
# One label per sample.
y.shape
(150,)
# Hold out a test set (default test_size=0.25). The original call passed no
# random_state, so every run produced a different split and none of the
# numbers below were reproducible; fixing the seed makes the notebook rerunnable.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Gaussian Naive Bayes classifier with default hyperparameters
# (priors estimated from data, var_smoothing=1e-9).
gnb = GaussianNB()
# Before fitting, the estimator only stores its constructor hyperparameters.
gnb.__dict__
{'priors': None, 'var_smoothing': 1e-09}
# Fit the model: estimates per-class feature means (theta_) and variances (var_)
# plus class priors from the training data.
gnb.fit(X_train,y_train)
GaussianNB()
# After fitting, learned attributes (trailing-underscore names such as
# classes_, theta_, var_, class_prior_) appear alongside the hyperparameters.
gnb.__dict__
{'priors': None,
 'var_smoothing': 1e-09,
 'classes_': array([0, 1, 2]),
 'n_features_in_': 4,
 'epsilon_': 3.0144610969387765e-09,
 'theta_': array([[4.98108108, 3.39459459, 1.45675676, 0.23243243],
        [5.98571429, 2.78857143, 4.23714286, 1.32285714],
        [6.4625    , 2.95      , 5.45      , 2.045     ]]),
 'var_': array([[0.10964208, 0.12753835, 0.02785975, 0.00813733],
        [0.28351021, 0.09015511, 0.23719184, 0.0343347 ],
        [0.30484375, 0.0795    , 0.245     , 0.073475  ]]),
 'class_count_': array([37., 35., 40.]),
 'class_prior_': array([0.33035714, 0.3125    , 0.35714286])}
# Feature values of the first held-out test sample.
X_test[0]
array([7.7, 3. , 6.1, 2.3])
# Predict class labels for the entire test set.
y_pred = gnb.predict(X_test)
# Peek at the first few predictions ...
y_pred[:5]
array([2, 1, 1, 1, 0])
# ... and compare against the true labels.
y_test[:5]
array([2, 1, 1, 1, 0])
# Rows = true class, columns = predicted class; off-diagonal entries are errors.
confusion_matrix(y_test, y_pred)
array([[13,  0,  0],
       [ 0, 13,  2],
       [ 0,  0, 10]])
# Mean accuracy on the held-out test set.
gnb.score(X_test,y_test)
0.9473684210526315
# Second model with fixed, non-uniform class priors (must sum to 1),
# instead of priors estimated from the training data.
gnb2 = GaussianNB(priors=[.5,.25,.25])

# Cross-validated accuracy on the training set for each model
# (cross_val_score defaults to 5 folds); compare the two via their mean scores.
gnb2_cv_scores = cross_val_score(gnb2,X_train,y_train)
np.mean(gnb2_cv_scores)
0.9549407114624506
# Default-prior model, same protocol.
gnb_cv_scores = cross_val_score(gnb,X_train,y_train)
np.mean(gnb_cv_scores)
0.9640316205533598
# Per-class precision / recall / F1 plus overall accuracy on the test set.
print(classification_report(y_test,y_pred))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        13
           1       1.00      0.87      0.93        15
           2       0.83      1.00      0.91        10

    accuracy                           0.95        38
   macro avg       0.94      0.96      0.95        38
weighted avg       0.96      0.95      0.95        38
# Posterior probability P(class | x) for each test sample; each row sums to 1,
# and predict() returns the argmax of each row.
gnb.predict_proba(X_test)
array([[4.73792773e-295, 1.55911841e-010, 1.00000000e+000],
       [6.23770515e-125, 8.57049198e-001, 1.42950802e-001],
       [5.55205931e-083, 9.99926593e-001, 7.34067020e-005],
       [1.61637364e-093, 9.99840982e-001, 1.59017634e-004],
       [1.00000000e+000, 3.69492326e-018, 8.42177257e-028],
       [5.07071802e-172, 4.69323295e-001, 5.30676705e-001],
       [1.00000000e+000, 1.03206509e-016, 3.09052974e-026],
       [1.00000000e+000, 8.26429542e-015, 1.00074955e-023],
       [0.00000000e+000, 3.08364083e-011, 1.00000000e+000],
       [1.00000000e+000, 1.59915742e-012, 1.23028208e-021],
       [2.20027240e-082, 9.99975452e-001, 2.45475528e-005],
       [4.24272211e-119, 9.94208070e-001, 5.79192999e-003],
       [1.47234459e-232, 7.52172116e-008, 9.99999925e-001],
       [1.52845856e-109, 9.97282266e-001, 2.71773378e-003],
       [3.32020676e-090, 9.99646819e-001, 3.53180554e-004],
       [1.25132538e-218, 5.18097442e-004, 9.99481903e-001],
       [1.00000000e+000, 9.16766286e-019, 4.97510176e-029],
       [6.83044779e-160, 4.65736444e-002, 9.53426356e-001],
       [1.00000000e+000, 1.26993147e-018, 6.42702604e-029],
       [4.42043056e-082, 9.99958643e-001, 4.13567898e-005],
       [9.99999998e-001, 1.64961226e-009, 1.23378912e-019],
       [3.31780599e-153, 7.47612886e-002, 9.25238711e-001],
       [1.00000000e+000, 2.36431699e-016, 8.48907695e-027],
       [5.43387444e-134, 5.86807730e-001, 4.13192270e-001],
       [1.00000000e+000, 3.63442912e-021, 4.29068433e-030],
       [5.25436668e-253, 2.57568626e-006, 9.99997424e-001],
       [1.00000000e+000, 2.06728116e-018, 5.70404980e-029],
       [2.11985079e-071, 9.99980364e-001, 1.96357461e-005],
       [2.00086033e-286, 7.95723865e-009, 9.99999992e-001],
       [3.48702081e-155, 4.61600669e-001, 5.38399331e-001],
       [1.00000000e+000, 3.79706516e-014, 4.60292993e-024],
       [1.00000000e+000, 1.78238648e-018, 1.34923141e-028],
       [7.35311846e-064, 9.99997532e-001, 2.46798684e-006],
       [4.84388680e-310, 8.74379005e-009, 9.99999991e-001],
       [1.91033418e-085, 9.99777667e-001, 2.22333264e-004],
       [1.73402404e-055, 9.99999135e-001, 8.65370124e-007],
       [5.85941709e-235, 1.92478857e-005, 9.99980752e-001],
       [1.00000000e+000, 1.88981922e-018, 9.53670949e-028]])