Classification Tree¶

hangover_5.jpeg

Data for demo

Back to the spell book

1. Load data¶

1.1 Libraries¶

In [1]:
import pandas as pd
import numpy as np
In [2]:
hangover_df = pd.read_csv("hangover_5.csv")
hangover_df.head()
Out[2]:
ID Night Theme Number_of_Drinks Spent Chow Hangover
0 1 Fri 1 2 703 1 0
1 2 Sat 0 8 287 0 1
2 3 Wed 0 3 346 1 0
3 4 Sat 0 1 312 0 1
4 5 Mon 1 5 919 0 1
In [3]:
pd.DataFrame(hangover_df.columns.values, columns = ["Variables"])
Out[3]:
Variables
0 ID
1 Night
2 Theme
3 Number_of_Drinks
4 Spent
5 Chow
6 Hangover

1.2 Check data types¶

In [4]:
hangover_df.dtypes
Out[4]:
ID                   int64
Night               object
Theme                int64
Number_of_Drinks     int64
Spent                int64
Chow                 int64
Hangover             int64
dtype: object

Convert target variable to categorical.

In [5]:
hangover_df.Hangover = hangover_df.Hangover.astype("category")

Convert some predictors to categorical.

In [6]:
hangover_df.Night = hangover_df.Night.astype("category")
In [7]:
hangover_df.Theme = hangover_df.Theme.astype("category")
In [8]:
hangover_df.Chow = hangover_df.Chow.astype("category")

Or to do it altogether:

In [9]:
# cat_cols = ["Night", "Theme", "Chow"]
# hangover_df[cat_cols] = hangover_df[cat_cols].astype('category')

Check data types again.

In [10]:
hangover_df.dtypes
Out[10]:
ID                     int64
Night               category
Theme               category
Number_of_Drinks       int64
Spent                  int64
Chow                category
Hangover            category
dtype: object
In [11]:
night_1 = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
night_2 = ["Weekday", "Weekday", "Mid Week", "Mid Week", "Weekend", "Weekend", "Weekend"]

hangover_df["Week_Night_Type"] = hangover_df["Night"].replace(night_1, night_2)
hangover_df.head()
Out[11]:
ID Night Theme Number_of_Drinks Spent Chow Hangover Week_Night_Type
0 1 Fri 1 2 703 1 0 Weekend
1 2 Sat 0 8 287 0 1 Weekend
2 3 Wed 0 3 346 1 0 Mid Week
3 4 Sat 0 1 312 0 1 Weekend
4 5 Mon 1 5 919 0 1 Weekday

Filter the required variables.

In [12]:
hangover_df = hangover_df.iloc[:, [2, 3, 4, 5, 7, 6]]
hangover_df
Out[12]:
Theme Number_of_Drinks Spent Chow Week_Night_Type Hangover
0 1 2 703 1 Weekend 0
1 0 8 287 0 Weekend 1
2 0 3 346 1 Mid Week 0
3 0 1 312 0 Weekend 1
4 1 5 919 0 Weekday 1
... ... ... ... ... ... ...
1995 1 7 139 1 Weekday 1
1996 0 3 742 1 Mid Week 1
1997 1 7 549 0 Weekday 1
1998 0 1 448 0 Weekday 0
1999 1 4 368 1 Weekday 1

2000 rows × 6 columns

In [13]:
pd.DataFrame(hangover_df.columns.values, columns = ["Variables"])
Out[13]:
Variables
0 Theme
1 Number_of_Drinks
2 Spent
3 Chow
4 Week_Night_Type
5 Hangover

2. Training-Validation split¶

In [14]:
import sklearn
from sklearn.model_selection import train_test_split

sklearn's decision tree function does not work well with strings.

In [15]:
week_night_type_dummies = pd.get_dummies(hangover_df["Week_Night_Type"])
week_night_type_dummies
Out[15]:
Weekend Weekday Mid Week
0 1 0 0
1 1 0 0
2 0 0 1
3 1 0 0
4 0 1 0
... ... ... ...
1995 0 1 0
1996 0 0 1
1997 0 1 0
1998 0 1 0
1999 0 1 0

2000 rows × 3 columns

In [16]:
hangover_df = pd.concat([hangover_df, week_night_type_dummies], axis = 1)
hangover_df.head()
Out[16]:
Theme Number_of_Drinks Spent Chow Week_Night_Type Hangover Weekend Weekday Mid Week
0 1 2 703 1 Weekend 0 1 0 0
1 0 8 287 0 Weekend 1 1 0 0
2 0 3 346 1 Mid Week 0 0 0 1
3 0 1 312 0 Weekend 1 1 0 0
4 1 5 919 0 Weekday 1 0 1 0

Filter required variables. There is no need for 3 dummies for "Week_Night_Type".

In [17]:
hangover_df = hangover_df.iloc[:,[0, 1, 2, 3, 8, 7, 5]]
hangover_df
Out[17]:
Theme Number_of_Drinks Spent Chow Mid Week Weekday Hangover
0 1 2 703 1 0 0 0
1 0 8 287 0 0 0 1
2 0 3 346 1 1 0 0
3 0 1 312 0 0 0 1
4 1 5 919 0 0 1 1
... ... ... ... ... ... ... ...
1995 1 7 139 1 0 1 1
1996 0 3 742 1 1 0 1
1997 1 7 549 0 0 1 1
1998 0 1 448 0 0 1 0
1999 1 4 368 1 0 1 1

2000 rows × 7 columns

In [18]:
X = hangover_df.drop(columns = ["Hangover"])
y = hangover_df["Hangover"].astype("category")
In [19]:
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.3, random_state = 666)
In [20]:
train_X.head()
Out[20]:
Theme Number_of_Drinks Spent Chow Mid Week Weekday
817 0 3 644 1 0 0
1652 1 9 252 0 0 0
1186 1 9 216 0 0 0
1909 1 3 196 0 1 0
87 0 9 519 1 1 0
In [21]:
len(train_X)
Out[21]:
1400
In [22]:
len(train_y)
Out[22]:
1400
In [23]:
train_y.head()
Out[23]:
817     0
1652    1
1186    1
1909    0
87      1
Name: Hangover, dtype: category
Categories (2, int64): [0, 1]
In [24]:
len(train_y.index)
Out[24]:
1400
In [25]:
valid_X.head()
Out[25]:
Theme Number_of_Drinks Spent Chow Mid Week Weekday
1170 0 10 533 1 1 0
1852 0 6 257 0 1 0
1525 0 2 383 0 0 0
1537 0 6 355 0 0 1
127 1 7 46 1 0 1
In [26]:
len(valid_X)
Out[26]:
600
In [27]:
len(valid_y)
Out[27]:
600

3. Decision Tree¶

In [28]:
from sklearn.tree import DecisionTreeClassifier

3.1 Deep tree¶

3.1.1 The tree¶

In [29]:
full_tree = DecisionTreeClassifier(random_state = 666)
In [30]:
full_tree
Out[30]:
DecisionTreeClassifier(random_state=666)
In [31]:
full_tree_fit = full_tree.fit(train_X, train_y)

Plot the tree.

In [32]:
from sklearn import tree

Export the top levels for illustration using max_depth. Export the whole tree if max_depth is excluded.

In [33]:
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 3.50
|   |--- feature_0 <= 0.50
|   |   |--- feature_2 <= 908.00
|   |   |   |--- feature_4 <= 0.50
|   |   |   |   |--- feature_2 <= 828.00
|   |   |   |   |   |--- feature_3 <= 0.50
|   |   |   |   |   |   |--- truncated branch of depth 11
|   |   |   |   |   |--- feature_3 >  0.50
|   |   |   |   |   |   |--- truncated branch of depth 12
|   |   |   |   |--- feature_2 >  828.00
|   |   |   |   |   |--- feature_2 <= 862.50
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- feature_2 >  862.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |--- feature_4 >  0.50
|   |   |   |   |--- feature_2 <= 497.00
|   |   |   |   |   |--- feature_2 <= 127.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_2 >  127.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 >  497.00
|   |   |   |   |   |--- feature_2 <= 538.50
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- feature_2 >  538.50
|   |   |   |   |   |   |--- truncated branch of depth 9
|   |   |--- feature_2 >  908.00
|   |   |   |--- feature_5 <= 0.50
|   |   |   |   |--- class: 1
|   |   |   |--- feature_5 >  0.50
|   |   |   |   |--- feature_1 <= 2.50
|   |   |   |   |   |--- feature_2 <= 961.00
|   |   |   |   |   |   |--- class: 0
|   |   |   |   |   |--- feature_2 >  961.00
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_1 >  2.50
|   |   |   |   |   |--- class: 1
|   |--- feature_0 >  0.50
|   |   |--- feature_2 <= 983.50
|   |   |   |--- feature_3 <= 0.50
|   |   |   |   |--- feature_4 <= 0.50
|   |   |   |   |   |--- feature_2 <= 38.00
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 >  38.00
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |   |   |--- feature_4 >  0.50
|   |   |   |   |   |--- feature_2 <= 759.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_2 >  759.50
|   |   |   |   |   |   |--- class: 0
|   |   |   |--- feature_3 >  0.50
|   |   |   |   |--- feature_1 <= 2.50
|   |   |   |   |   |--- feature_2 <= 133.50
|   |   |   |   |   |   |--- class: 0
|   |   |   |   |   |--- feature_2 >  133.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_1 >  2.50
|   |   |   |   |   |--- feature_2 <= 46.00
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- feature_2 >  46.00
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 >  983.50
|   |   |   |--- feature_1 <= 2.50
|   |   |   |   |--- class: 0
|   |   |   |--- feature_1 >  2.50
|   |   |   |   |--- class: 1
|--- feature_1 >  3.50
|   |--- feature_1 <= 6.50
|   |   |--- feature_3 <= 0.50
|   |   |   |--- feature_4 <= 0.50
|   |   |   |   |--- feature_2 <= 338.00
|   |   |   |   |   |--- feature_2 <= 272.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |   |--- feature_2 >  272.50
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_2 >  338.00
|   |   |   |   |   |--- feature_2 <= 544.00
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |   |   |   |--- feature_2 >  544.00
|   |   |   |   |   |   |--- truncated branch of depth 14
|   |   |   |--- feature_4 >  0.50
|   |   |   |   |--- feature_2 <= 864.00
|   |   |   |   |   |--- feature_2 <= 167.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_2 >  167.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |--- feature_2 >  864.00
|   |   |   |   |   |--- feature_2 <= 944.50
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- feature_2 >  944.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |--- feature_3 >  0.50
|   |   |   |--- feature_5 <= 0.50
|   |   |   |   |--- feature_2 <= 427.50
|   |   |   |   |   |--- feature_2 <= 132.00
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_2 >  132.00
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |--- feature_2 >  427.50
|   |   |   |   |   |--- feature_2 <= 994.00
|   |   |   |   |   |   |--- truncated branch of depth 12
|   |   |   |   |   |--- feature_2 >  994.00
|   |   |   |   |   |   |--- class: 0
|   |   |   |--- feature_5 >  0.50
|   |   |   |   |--- feature_2 <= 568.00
|   |   |   |   |   |--- feature_2 <= 137.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 >  137.00
|   |   |   |   |   |   |--- truncated branch of depth 15
|   |   |   |   |--- feature_2 >  568.00
|   |   |   |   |   |--- feature_2 <= 946.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 >  946.50
|   |   |   |   |   |   |--- class: 0
|   |--- feature_1 >  6.50
|   |   |--- feature_3 <= 0.50
|   |   |   |--- feature_2 <= 74.50
|   |   |   |   |--- class: 1
|   |   |   |--- feature_2 >  74.50
|   |   |   |   |--- feature_2 <= 83.50
|   |   |   |   |   |--- class: 0
|   |   |   |   |--- feature_2 >  83.50
|   |   |   |   |   |--- feature_2 <= 751.00
|   |   |   |   |   |   |--- truncated branch of depth 16
|   |   |   |   |   |--- feature_2 >  751.00
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_3 >  0.50
|   |   |   |--- feature_2 <= 867.00
|   |   |   |   |--- feature_0 <= 0.50
|   |   |   |   |   |--- feature_2 <= 622.00
|   |   |   |   |   |   |--- truncated branch of depth 11
|   |   |   |   |   |--- feature_2 >  622.00
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_0 >  0.50
|   |   |   |   |   |--- feature_2 <= 811.50
|   |   |   |   |   |   |--- truncated branch of depth 18
|   |   |   |   |   |--- feature_2 >  811.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |--- feature_2 >  867.00
|   |   |   |   |--- feature_2 <= 938.50
|   |   |   |   |   |--- feature_2 <= 924.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 >  924.50
|   |   |   |   |   |   |--- class: 0
|   |   |   |   |--- feature_2 >  938.50
|   |   |   |   |   |--- feature_4 <= 0.50
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- feature_4 >  0.50
|   |   |   |   |   |   |--- truncated branch of depth 2

Plot the top 5 levels for illustration using max_depth. Plot the whole tree if max_depth is excluded.

In [34]:
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
Out[34]:
[Text(0.484375, 0.9285714285714286, 'Number_of_Drinks <= 3.5\ngini = 0.489\nsamples = 1400\nvalue = [597, 803]'),
 Text(0.2528409090909091, 0.7857142857142857, 'Theme <= 0.5\ngini = 0.445\nsamples = 416\nvalue = [277, 139]'),
 Text(0.14772727272727273, 0.6428571428571429, 'Spent <= 908.0\ngini = 0.499\nsamples = 203\nvalue = [96, 107]'),
 Text(0.09090909090909091, 0.5, 'Mid Week <= 0.5\ngini = 0.5\nsamples = 187\nvalue = [95, 92]'),
 Text(0.045454545454545456, 0.35714285714285715, 'Spent <= 828.0\ngini = 0.487\nsamples = 129\nvalue = [54, 75]'),
 Text(0.022727272727272728, 0.21428571428571427, 'Chow <= 0.5\ngini = 0.493\nsamples = 118\nvalue = [52, 66]'),
 Text(0.011363636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.03409090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.06818181818181818, 0.21428571428571427, 'Spent <= 862.5\ngini = 0.298\nsamples = 11\nvalue = [2, 9]'),
 Text(0.056818181818181816, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.07954545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.13636363636363635, 0.35714285714285715, 'Spent <= 497.0\ngini = 0.414\nsamples = 58\nvalue = [41, 17]'),
 Text(0.11363636363636363, 0.21428571428571427, 'Spent <= 127.5\ngini = 0.278\nsamples = 24\nvalue = [20, 4]'),
 Text(0.10227272727272728, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.125, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.1590909090909091, 0.21428571428571427, 'Spent <= 538.5\ngini = 0.472\nsamples = 34\nvalue = [21, 13]'),
 Text(0.14772727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.17045454545454544, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.20454545454545456, 0.5, 'Weekday <= 0.5\ngini = 0.117\nsamples = 16\nvalue = [1, 15]'),
 Text(0.19318181818181818, 0.35714285714285715, 'gini = 0.0\nsamples = 9\nvalue = [0, 9]'),
 Text(0.2159090909090909, 0.35714285714285715, 'Number_of_Drinks <= 2.5\ngini = 0.245\nsamples = 7\nvalue = [1, 6]'),
 Text(0.20454545454545456, 0.21428571428571427, 'Spent <= 961.0\ngini = 0.444\nsamples = 3\nvalue = [1, 2]'),
 Text(0.19318181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.2159090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.22727272727272727, 0.21428571428571427, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]'),
 Text(0.35795454545454547, 0.6428571428571429, 'Spent <= 983.5\ngini = 0.255\nsamples = 213\nvalue = [181, 32]'),
 Text(0.3181818181818182, 0.5, 'Chow <= 0.5\ngini = 0.24\nsamples = 208\nvalue = [179, 29]'),
 Text(0.2727272727272727, 0.35714285714285715, 'Mid Week <= 0.5\ngini = 0.18\nsamples = 110\nvalue = [99, 11]'),
 Text(0.25, 0.21428571428571427, 'Spent <= 38.0\ngini = 0.134\nsamples = 83\nvalue = [77, 6]'),
 Text(0.23863636363636365, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.26136363636363635, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.29545454545454547, 0.21428571428571427, 'Spent <= 759.5\ngini = 0.302\nsamples = 27\nvalue = [22, 5]'),
 Text(0.2840909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3068181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.36363636363636365, 0.35714285714285715, 'Number_of_Drinks <= 2.5\ngini = 0.3\nsamples = 98\nvalue = [80, 18]'),
 Text(0.3409090909090909, 0.21428571428571427, 'Spent <= 133.5\ngini = 0.245\nsamples = 63\nvalue = [54, 9]'),
 Text(0.32954545454545453, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3522727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.38636363636363635, 0.21428571428571427, 'Spent <= 46.0\ngini = 0.382\nsamples = 35\nvalue = [26, 9]'),
 Text(0.375, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3977272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3977272727272727, 0.5, 'Number_of_Drinks <= 2.5\ngini = 0.48\nsamples = 5\nvalue = [2, 3]'),
 Text(0.38636363636363635, 0.35714285714285715, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'),
 Text(0.4090909090909091, 0.35714285714285715, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'),
 Text(0.7159090909090909, 0.7857142857142857, 'Number_of_Drinks <= 6.5\ngini = 0.439\nsamples = 984\nvalue = [320, 664]'),
 Text(0.5909090909090909, 0.6428571428571429, 'Chow <= 0.5\ngini = 0.49\nsamples = 429\nvalue = [184, 245]'),
 Text(0.5, 0.5, 'Mid Week <= 0.5\ngini = 0.5\nsamples = 221\nvalue = [113, 108]'),
 Text(0.45454545454545453, 0.35714285714285715, 'Spent <= 338.0\ngini = 0.496\nsamples = 161\nvalue = [73, 88]'),
 Text(0.4318181818181818, 0.21428571428571427, 'Spent <= 272.5\ngini = 0.429\nsamples = 45\nvalue = [14, 31]'),
 Text(0.42045454545454547, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4431818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4772727272727273, 0.21428571428571427, 'Spent <= 544.0\ngini = 0.5\nsamples = 116\nvalue = [59, 57]'),
 Text(0.4659090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.48863636363636365, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5454545454545454, 0.35714285714285715, 'Spent <= 864.0\ngini = 0.444\nsamples = 60\nvalue = [40, 20]'),
 Text(0.5227272727272727, 0.21428571428571427, 'Spent <= 167.5\ngini = 0.406\nsamples = 53\nvalue = [38, 15]'),
 Text(0.5113636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5340909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5681818181818182, 0.21428571428571427, 'Spent <= 944.5\ngini = 0.408\nsamples = 7\nvalue = [2, 5]'),
 Text(0.5568181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5795454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6818181818181818, 0.5, 'Weekday <= 0.5\ngini = 0.45\nsamples = 208\nvalue = [71, 137]'),
 Text(0.6363636363636364, 0.35714285714285715, 'Spent <= 427.5\ngini = 0.417\nsamples = 152\nvalue = [45, 107]'),
 Text(0.6136363636363636, 0.21428571428571427, 'Spent <= 132.0\ngini = 0.464\nsamples = 63\nvalue = [23, 40]'),
 Text(0.6022727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.625, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6590909090909091, 0.21428571428571427, 'Spent <= 994.0\ngini = 0.372\nsamples = 89\nvalue = [22, 67]'),
 Text(0.6477272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6704545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7272727272727273, 0.35714285714285715, 'Spent <= 568.0\ngini = 0.497\nsamples = 56\nvalue = [26, 30]'),
 Text(0.7045454545454546, 0.21428571428571427, 'Spent <= 137.0\ngini = 0.422\nsamples = 33\nvalue = [10, 23]'),
 Text(0.6931818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7159090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.75, 0.21428571428571427, 'Spent <= 946.5\ngini = 0.423\nsamples = 23\nvalue = [16, 7]'),
 Text(0.7386363636363636, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7613636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8409090909090909, 0.6428571428571429, 'Chow <= 0.5\ngini = 0.37\nsamples = 555\nvalue = [136, 419]'),
 Text(0.7727272727272727, 0.5, 'Spent <= 74.5\ngini = 0.27\nsamples = 286\nvalue = [46, 240]'),
 Text(0.7613636363636364, 0.35714285714285715, 'gini = 0.0\nsamples = 15\nvalue = [0, 15]'),
 Text(0.7840909090909091, 0.35714285714285715, 'Spent <= 83.5\ngini = 0.282\nsamples = 271\nvalue = [46, 225]'),
 Text(0.7727272727272727, 0.21428571428571427, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
 Text(0.7954545454545454, 0.21428571428571427, 'Spent <= 751.0\ngini = 0.278\nsamples = 270\nvalue = [45, 225]'),
 Text(0.7840909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8068181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9090909090909091, 0.5, 'Spent <= 867.0\ngini = 0.445\nsamples = 269\nvalue = [90, 179]'),
 Text(0.8636363636363636, 0.35714285714285715, 'Theme <= 0.5\ngini = 0.471\nsamples = 224\nvalue = [85, 139]'),
 Text(0.8409090909090909, 0.21428571428571427, 'Spent <= 622.0\ngini = 0.405\nsamples = 110\nvalue = [31, 79]'),
 Text(0.8295454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8522727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8863636363636364, 0.21428571428571427, 'Spent <= 811.5\ngini = 0.499\nsamples = 114\nvalue = [54, 60]'),
 Text(0.875, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8977272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9545454545454546, 0.35714285714285715, 'Spent <= 938.5\ngini = 0.198\nsamples = 45\nvalue = [5, 40]'),
 Text(0.9318181818181818, 0.21428571428571427, 'Spent <= 924.5\ngini = 0.308\nsamples = 21\nvalue = [4, 17]'),
 Text(0.9204545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9431818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9772727272727273, 0.21428571428571427, 'Mid Week <= 0.5\ngini = 0.08\nsamples = 24\nvalue = [1, 23]'),
 Text(0.9659090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9886363636363636, 0.07142857142857142, '\n  (...)  \n')]

Export tree and convert to a picture file.

In [35]:
from sklearn.tree import export_graphviz
In [36]:
dot_data = export_graphviz(full_tree, out_file='full_tree_v2.dot', feature_names = train_X.columns)

Online conversion

Not a very useful visualisation. But can be used for prediction.

Full tree

3.1.2 Probabilities¶

In [37]:
predProb_train = full_tree.predict_proba(train_X)
predProb_train
Out[37]:
array([[1., 0.],
       [0., 1.],
       [0., 1.],
       ...,
       [0., 1.],
       [1., 0.],
       [0., 1.]])
In [38]:
pd.DataFrame(predProb_train, columns = full_tree.classes_)
Out[38]:
0 1
0 1.0 0.0
1 0.0 1.0
2 0.0 1.0
3 1.0 0.0
4 0.0 1.0
... ... ...
1395 1.0 0.0
1396 0.0 1.0
1397 0.0 1.0
1398 1.0 0.0
1399 0.0 1.0

1400 rows × 2 columns

In [39]:
predProb_valid = full_tree.predict_proba(valid_X)
predProb_valid
Out[39]:
array([[0., 1.],
       [1., 0.],
       [1., 0.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]])
In [40]:
pd.DataFrame(predProb_valid, columns = full_tree.classes_)
Out[40]:
0 1
0 0.0 1.0
1 1.0 0.0
2 1.0 0.0
3 0.0 1.0
4 0.0 1.0
... ... ...
595 1.0 0.0
596 0.0 1.0
597 1.0 0.0
598 1.0 0.0
599 1.0 0.0

600 rows × 2 columns

3.1.3 Predictions¶

In [41]:
train_y_pred = full_tree.predict(train_X)
train_y_pred
Out[41]:
array([0, 1, 1, ..., 1, 0, 1], dtype=int64)
In [42]:
valid_y_pred = full_tree.predict(valid_X)
valid_y_pred
Out[42]:
array([1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
       0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0,
       1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,
       0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
       1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0,
       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
       1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0,
       0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1,
       0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1,
       0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1,
       1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
       0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
       1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1,
       1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0,
       1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       0, 0, 1, 0, 0, 0], dtype=int64)

3.1.4 Model evaluation¶

Confusion matrix.

In [43]:
from sklearn.metrics import confusion_matrix, accuracy_score

Training set.

In [44]:
confusion_matrix_train = confusion_matrix(train_y, train_y_pred)
confusion_matrix_train
Out[44]:
array([[597,   0],
       [  1, 802]], dtype=int64)
In [45]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

confusion_matrix_train_display = ConfusionMatrixDisplay(confusion_matrix_train, display_labels = full_tree.classes_)
confusion_matrix_train_display.plot()
plt.grid(False)
In [46]:
accuracy_train = accuracy_score(train_y, train_y_pred)
accuracy_train
Out[46]:
0.9992857142857143
In [47]:
from sklearn.metrics import classification_report

print(classification_report(train_y, train_y_pred))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       597
           1       1.00      1.00      1.00       803

    accuracy                           1.00      1400
   macro avg       1.00      1.00      1.00      1400
weighted avg       1.00      1.00      1.00      1400

Validation set.

In [48]:
from sklearn.metrics import confusion_matrix, accuracy_score

confusion_matrix_valid = confusion_matrix(valid_y, valid_y_pred)
confusion_matrix_valid
Out[48]:
array([[137, 117],
       [138, 208]], dtype=int64)
In [49]:
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

confusion_matrix_valid_display = ConfusionMatrixDisplay(confusion_matrix_valid, display_labels = full_tree.classes_)
confusion_matrix_valid_display.plot()
plt.grid(False)
In [50]:
accuracy_valid = accuracy_score(valid_y, valid_y_pred)
accuracy_valid
Out[50]:
0.575
In [51]:
from sklearn.metrics import classification_report

print(classification_report(valid_y, valid_y_pred))
              precision    recall  f1-score   support

           0       0.50      0.54      0.52       254
           1       0.64      0.60      0.62       346

    accuracy                           0.57       600
   macro avg       0.57      0.57      0.57       600
weighted avg       0.58      0.57      0.58       600

ROC.

In [52]:
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
In [53]:
fpr1, tpr1, thresh1 = roc_curve(valid_y, predProb_valid[:,1], pos_label = 1)
In [54]:
import matplotlib.pyplot as plt
plt.style.use("seaborn")

plt.plot(fpr1, tpr1, linestyle = '-', color = "blue", label = "Full Hangover Tree")

# roc curve for tpr = fpr (random line) 
random_probs = [0 for i in range(len(valid_y))]

p_fpr, p_tpr, _ = roc_curve(valid_y, random_probs, pos_label = 1)

plt.plot(p_fpr, p_tpr, linestyle = '--', color = "black", label = "Random Hangover")

# If desired
plt.legend()

plt.title("Full Hangover Tree ROC")

plt.xlabel("False Positive Rate")

plt.ylabel("True Positive Rate")

# to save the plot
# plt.savefig("whatever_name",dpi = 300)
Out[54]:
Text(0, 0.5, 'True Positive Rate')
In [55]:
from sklearn.metrics import roc_auc_score
In [56]:
auc = roc_auc_score(valid_y, predProb_valid[:,1])
auc
Out[56]:
0.571042510582131

3.2 A shallower tree¶

3.2.1 The tree¶

In [57]:
small_tree = DecisionTreeClassifier(random_state = 666, max_depth = 3, min_samples_split = 25, min_samples_leaf = 11)
In [58]:
small_tree_fit = small_tree.fit(train_X, train_y)
In [59]:
text_representation = tree.export_text(small_tree)
print(text_representation)
|--- feature_1 <= 3.50
|   |--- feature_0 <= 0.50
|   |   |--- feature_2 <= 908.00
|   |   |   |--- class: 0
|   |   |--- feature_2 >  908.00
|   |   |   |--- class: 1
|   |--- feature_0 >  0.50
|   |   |--- feature_2 <= 973.50
|   |   |   |--- class: 0
|   |   |--- feature_2 >  973.50
|   |   |   |--- class: 0
|--- feature_1 >  3.50
|   |--- feature_1 <= 6.50
|   |   |--- feature_3 <= 0.50
|   |   |   |--- class: 0
|   |   |--- feature_3 >  0.50
|   |   |   |--- class: 1
|   |--- feature_1 >  6.50
|   |   |--- feature_3 <= 0.50
|   |   |   |--- class: 1
|   |   |--- feature_3 >  0.50
|   |   |   |--- class: 1

In [60]:
tree.plot_tree(small_tree)
Out[60]:
[Text(0.5, 0.875, 'X[1] <= 3.5\ngini = 0.489\nsamples = 1400\nvalue = [597, 803]'),
 Text(0.25, 0.625, 'X[0] <= 0.5\ngini = 0.445\nsamples = 416\nvalue = [277, 139]'),
 Text(0.125, 0.375, 'X[2] <= 908.0\ngini = 0.499\nsamples = 203\nvalue = [96, 107]'),
 Text(0.0625, 0.125, 'gini = 0.5\nsamples = 187\nvalue = [95, 92]'),
 Text(0.1875, 0.125, 'gini = 0.117\nsamples = 16\nvalue = [1, 15]'),
 Text(0.375, 0.375, 'X[2] <= 973.5\ngini = 0.255\nsamples = 213\nvalue = [181, 32]'),
 Text(0.3125, 0.125, 'gini = 0.239\nsamples = 202\nvalue = [174, 28]'),
 Text(0.4375, 0.125, 'gini = 0.463\nsamples = 11\nvalue = [7, 4]'),
 Text(0.75, 0.625, 'X[1] <= 6.5\ngini = 0.439\nsamples = 984\nvalue = [320, 664]'),
 Text(0.625, 0.375, 'X[3] <= 0.5\ngini = 0.49\nsamples = 429\nvalue = [184, 245]'),
 Text(0.5625, 0.125, 'gini = 0.5\nsamples = 221\nvalue = [113, 108]'),
 Text(0.6875, 0.125, 'gini = 0.45\nsamples = 208\nvalue = [71, 137]'),
 Text(0.875, 0.375, 'X[3] <= 0.5\ngini = 0.37\nsamples = 555\nvalue = [136, 419]'),
 Text(0.8125, 0.125, 'gini = 0.27\nsamples = 286\nvalue = [46, 240]'),
 Text(0.9375, 0.125, 'gini = 0.445\nsamples = 269\nvalue = [90, 179]')]

Export tree and convert to a picture file.

In [61]:
dot_data_2 = export_graphviz(small_tree, out_file='small_tree_v2.dot', feature_names = train_X.columns)

Online conversion

Looks much better!

small_tree_v2.png

3.2.2 Probabilities¶

In [62]:
predProb_train_2 = small_tree.predict_proba(train_X)
predProb_train_2
Out[62]:
array([[0.50802139, 0.49197861],
       [0.16083916, 0.83916084],
       [0.16083916, 0.83916084],
       ...,
       [0.34134615, 0.65865385],
       [0.51131222, 0.48868778],
       [0.51131222, 0.48868778]])
In [63]:
predProb_train_2_df = pd.DataFrame(predProb_train_2, columns = small_tree.classes_)
predProb_train_2_df
Out[63]:
0 1
0 0.508021 0.491979
1 0.160839 0.839161
2 0.160839 0.839161
3 0.861386 0.138614
4 0.334572 0.665428
... ... ...
1395 0.508021 0.491979
1396 0.160839 0.839161
1397 0.341346 0.658654
1398 0.511312 0.488688
1399 0.511312 0.488688

1400 rows × 2 columns

In [64]:
predProb_valid_2 = small_tree.predict_proba(valid_X)
predProb_valid_2
Out[64]:
array([[0.33457249, 0.66542751],
       [0.51131222, 0.48868778],
       [0.50802139, 0.49197861],
       ...,
       [0.86138614, 0.13861386],
       [0.16083916, 0.83916084],
       [0.16083916, 0.83916084]])
In [65]:
predProb_valid_2_df = pd.DataFrame(predProb_valid_2, columns = small_tree.classes_)
predProb_valid_2_df
Out[65]:
0 1
0 0.334572 0.665428
1 0.511312 0.488688
2 0.508021 0.491979
3 0.511312 0.488688
4 0.334572 0.665428
... ... ...
595 0.861386 0.138614
596 0.511312 0.488688
597 0.861386 0.138614
598 0.160839 0.839161
599 0.160839 0.839161

600 rows × 2 columns

3.2.3 Predictions¶

In [66]:
train_y_pred_2 = small_tree.predict(train_X)
train_y_pred_2
Out[66]:
array([0, 1, 1, ..., 1, 0, 0], dtype=int64)
In [67]:
valid_y_pred_2 = small_tree.predict(valid_X)
valid_y_pred_2
Out[67]:
array([1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0,
       0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1,
       0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,
       0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
       1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0,
       1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0,
       0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1,
       1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1,
       1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
       0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0,
       0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1,
       0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0,
       0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
       1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1,
       0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1,
       0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0,
       1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
       1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
       1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1,
       0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1,
       0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1,
       1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 1], dtype=int64)

3.2.4 Model evaluation¶

Confusion matrix.

In [68]:
from sklearn.metrics import confusion_matrix, accuracy_score

Training set.

In [69]:
confusion_matrix_train_2 = confusion_matrix(train_y, train_y_pred)
confusion_matrix_train_2
Out[69]:
array([[597,   0],
       [  1, 802]], dtype=int64)
In [70]:
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

confusion_matrix_train_2_display = ConfusionMatrixDisplay(confusion_matrix_train_2, display_labels = small_tree.classes_)
confusion_matrix_train_2_display.plot()
plt.grid(False)
In [71]:
accuracy_train_2 = accuracy_score(train_y, train_y_pred_2)
accuracy_train_2
Out[71]:
0.6857142857142857
In [72]:
from sklearn.metrics import classification_report

print(classification_report(train_y, train_y_pred_2))
              precision    recall  f1-score   support

           0       0.63      0.65      0.64       597
           1       0.73      0.71      0.72       803

    accuracy                           0.69      1400
   macro avg       0.68      0.68      0.68      1400
weighted avg       0.69      0.69      0.69      1400

Now for the validation set

In [73]:
confusion_matrix_valid_2 = confusion_matrix(valid_y, valid_y_pred)
confusion_matrix_valid_2
Out[73]:
array([[137, 117],
       [138, 208]], dtype=int64)
In [74]:
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

confusion_matrix_valid_2_display = ConfusionMatrixDisplay(confusion_matrix_valid_2, display_labels = small_tree.classes_)
confusion_matrix_valid_2_display.plot()
plt.grid(False)
In [75]:
# from sklearn.metrics import classification_report

print(classification_report(valid_y, valid_y_pred_2))
              precision    recall  f1-score   support

           0       0.59      0.61      0.60       254
           1       0.71      0.68      0.70       346

    accuracy                           0.66       600
   macro avg       0.65      0.65      0.65       600
weighted avg       0.66      0.66      0.66       600

Change cutoff to 0.7

New confusion matrix for the training set

In [76]:
train_y_pred_2_df = pd.DataFrame(train_y_pred_2, columns = ["pred_50"])
train_y_pred_2_df
Out[76]:
pred_50
0 0
1 1
2 1
3 0
4 1
... ...
1395 0
1396 1
1397 1
1398 0
1399 0

1400 rows × 1 columns

In [77]:
predProb_train_2_df
Out[77]:
0 1
0 0.508021 0.491979
1 0.160839 0.839161
2 0.160839 0.839161
3 0.861386 0.138614
4 0.334572 0.665428
... ... ...
1395 0.508021 0.491979
1396 0.160839 0.839161
1397 0.341346 0.658654
1398 0.511312 0.488688
1399 0.511312 0.488688

1400 rows × 2 columns

In [78]:
import numpy as np
train_y_pred_2_df["pred_70"] = np.where(predProb_train_2_df.iloc[:, 1] >= 0.7, 
                                                     1, 0)
train_y_pred_2_df.head()
Out[78]:
pred_50 pred_70
0 0 0
1 1 1
2 1 1
3 0 0
4 1 0

Confusion matrix for training set with new cutoff

In [79]:
confusion_matrix_train_2_70 = confusion_matrix(train_y, train_y_pred_2_df["pred_70"])
confusion_matrix_train_2_70
Out[79]:
array([[550,  47],
       [548, 255]], dtype=int64)
In [80]:
confusion_matrix_train_2_70_display = ConfusionMatrixDisplay(confusion_matrix_train_2_70, 
                                                               display_labels = small_tree.classes_)
confusion_matrix_train_2_70_display.plot()
plt.grid(False)

New confusion matrix for the validation set

In [81]:
valid_y_pred_2_df = pd.DataFrame(valid_y_pred_2, columns = ["pred_50"])
valid_y_pred_2_df
Out[81]:
pred_50
0 1
1 0
2 0
3 0
4 1
... ...
595 0
596 0
597 0
598 1
599 1

600 rows × 1 columns

In [82]:
predProb_valid_2_df
Out[82]:
0 1
0 0.334572 0.665428
1 0.511312 0.488688
2 0.508021 0.491979
3 0.511312 0.488688
4 0.334572 0.665428
... ... ...
595 0.861386 0.138614
596 0.511312 0.488688
597 0.861386 0.138614
598 0.160839 0.839161
599 0.160839 0.839161

600 rows × 2 columns

In [83]:
import numpy as np
valid_y_pred_2_df["pred_70"] = np.where(predProb_valid_2_df.iloc[:, 1] >= 0.7, 
                                                     1, 0)
valid_y_pred_2_df.head()
Out[83]:
pred_50 pred_70
0 1 0
1 0 0
2 0 0
3 0 0
4 1 0

Confusion matrix for validation set with new cutoff

In [84]:
confusion_matrix_valid_2_70 = confusion_matrix(valid_y, valid_y_pred_2_df["pred_70"])
confusion_matrix_valid_2_70
Out[84]:
array([[231,  23],
       [233, 113]], dtype=int64)
In [85]:
confusion_matrix_valid_2_70_display = ConfusionMatrixDisplay(confusion_matrix_valid_2_70, 
                                                               display_labels = small_tree.classes_)
confusion_matrix_valid_2_70_display.plot()
plt.grid(False)

ROC

In [86]:
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
In [87]:
fpr1, tpr1, thresh1 = roc_curve(valid_y, predProb_valid_2[:,1], pos_label = 1)
In [88]:
import matplotlib.pyplot as plt
plt.style.use("seaborn")

plt.plot(fpr1, tpr1, linestyle = '-', color = "red", label = "Small Hangover Tree")

# roc curve for tpr = fpr (random line) 
random_probs = [0 for i in range(len(valid_y))]

p_fpr, p_tpr, _ = roc_curve(valid_y, random_probs, pos_label = 1)

plt.plot(p_fpr, p_tpr, linestyle = '--', color = "green", label = "Random Hangover")

# If desired
plt.legend()

plt.title("Small Hangover Tree ROC")

plt.xlabel("False Positive Rate")

plt.ylabel("True Positive Rate")

# to save the plot
# plt.savefig("whatever_name",dpi = 300)
Out[88]:
Text(0, 0.5, 'True Positive Rate')
In [89]:
from sklearn.metrics import roc_auc_score
In [90]:
auc = roc_auc_score(valid_y, predProb_valid_2[:,1])
auc
Out[90]:
0.7259569432433662

5. New Record¶

New Record

In [91]:
class_tr_new = pd.read_csv("hangover_new_data.csv")
class_tr_new
Out[91]:
Theme Number_of_Drinks Spent Chow Mid Week Weekday
0 0 6 666 1 0 0
In [92]:
predProb_class_tr_new = small_tree.predict_proba(class_tr_new)
predProb_class_tr_new
Out[92]:
array([[0.34134615, 0.65865385]])
In [93]:
predProb_class_tr_new_df = pd.DataFrame(predProb_class_tr_new, 
                                        columns = small_tree.classes_)
predProb_class_tr_new_df
Out[93]:
0 1
0 0.341346 0.658654
In [94]:
class_tr_new_pred = small_tree.predict(class_tr_new)
class_tr_new_pred
Out[94]:
array([1], dtype=int64)
In [95]:
class_tr_new_pred_df = pd.DataFrame(class_tr_new_pred, columns = ["pred_50"])
class_tr_new_pred_df
Out[95]:
pred_50
0 1
In [96]:
import numpy as np
class_tr_new_pred_df["pred_60"] = np.where(predProb_class_tr_new_df.iloc[:, 1] >= 0.6, 
                                                     1, 0)
class_tr_new_pred_df.head()
Out[96]:
pred_50 pred_60
0 1 1

It's time to retrace your steps... HEH HEH HEH HEH HEH :-)

hangover_logo.jpeg

6. Random Forest¶

In [97]:
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(max_depth = 10, random_state = 666)
rf.fit(train_X, train_y)
Out[97]:
RandomForestClassifier(max_depth=10, random_state=666)
In [98]:
train_y_pred_rf = rf.predict(train_X)
train_y_pred_rf
Out[98]:
array([1, 1, 1, ..., 1, 0, 1], dtype=int64)
In [99]:
predProb_train_rf = rf.predict_proba(train_X)
predProb_train_rf
Out[99]:
array([[0.4109716 , 0.5890284 ],
       [0.16733854, 0.83266146],
       [0.27076043, 0.72923957],
       ...,
       [0.21319023, 0.78680977],
       [0.5055186 , 0.4944814 ],
       [0.11442484, 0.88557516]])
In [100]:
pd.DataFrame(predProb_train_rf, columns = rf.classes_)
Out[100]:
0 1
0 0.410972 0.589028
1 0.167339 0.832661
2 0.270760 0.729240
3 0.958139 0.041861
4 0.216078 0.783922
... ... ...
1395 0.603634 0.396366
1396 0.020870 0.979130
1397 0.213190 0.786810
1398 0.505519 0.494481
1399 0.114425 0.885575

1400 rows × 2 columns

In [101]:
valid_y_pred_rf = rf.predict(valid_X)
valid_y_pred_rf
Out[101]:
array([1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1,
       1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1,
       0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,
       1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1,
       1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,
       0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
       1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1,
       0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1,
       1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,
       0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0,
       0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0,
       1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1,
       1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0,
       1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 1], dtype=int64)
In [102]:
predProb_valid_rf = rf.predict_proba(valid_X)
predProb_valid_rf
Out[102]:
array([[0.19484426, 0.80515574],
       [0.61708693, 0.38291307],
       [0.37797144, 0.62202856],
       ...,
       [0.83796426, 0.16203574],
       [0.17397095, 0.82602905],
       [0.07917024, 0.92082976]])
In [103]:
predProb_valid_rf = rf.predict_proba(valid_X)
predProb_valid_rf
Out[103]:
array([[0.19484426, 0.80515574],
       [0.61708693, 0.38291307],
       [0.37797144, 0.62202856],
       ...,
       [0.83796426, 0.16203574],
       [0.17397095, 0.82602905],
       [0.07917024, 0.92082976]])
In [104]:
confusion_matrix_train_rf = confusion_matrix(train_y, train_y_pred_rf)
confusion_matrix_train_rf
Out[104]:
array([[519,  78],
       [ 22, 781]], dtype=int64)
In [105]:
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

confusion_matrix_train_rf_display = ConfusionMatrixDisplay(confusion_matrix_train_rf, display_labels = rf.classes_)
confusion_matrix_train_rf_display.plot()
plt.grid(False)
In [106]:
# from sklearn.metrics import classification_report

print(classification_report(train_y, train_y_pred_rf))
              precision    recall  f1-score   support

           0       0.96      0.87      0.91       597
           1       0.91      0.97      0.94       803

    accuracy                           0.93      1400
   macro avg       0.93      0.92      0.93      1400
weighted avg       0.93      0.93      0.93      1400

In [107]:
confusion_matrix_valid_rf = confusion_matrix(valid_y, valid_y_pred_rf)
confusion_matrix_valid_rf
Out[107]:
array([[129, 125],
       [ 91, 255]], dtype=int64)
In [108]:
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# import matplotlib.pyplot as plt

confusion_matrix_valid_rf_display = ConfusionMatrixDisplay(confusion_matrix_valid_rf, display_labels = rf.classes_)
confusion_matrix_valid_rf_display.plot()
plt.grid(False)
In [109]:
# from sklearn.metrics import classification_report

print(classification_report(valid_y, valid_y_pred_rf))
              precision    recall  f1-score   support

           0       0.59      0.51      0.54       254
           1       0.67      0.74      0.70       346

    accuracy                           0.64       600
   macro avg       0.63      0.62      0.62       600
weighted avg       0.64      0.64      0.64       600

In [110]:
# from sklearn import metrics
# import matplotlib.pyplot as plt
# from sklearn.metrics import roc_curve

fpr1, tpr1, thresh1 = roc_curve(valid_y, predProb_valid_rf[:,1], pos_label = 1)
In [111]:
import matplotlib.pyplot as plt
plt.style.use("seaborn")

plt.plot(fpr1, tpr1, linestyle = '-', color = "orange", label = "Random Forest Hangover")

# roc curve for tpr = fpr (random line) 
random_probs = [0 for i in range(len(valid_y))]

p_fpr, p_tpr, _ = roc_curve(valid_y, random_probs, pos_label = 1)

plt.plot(p_fpr, p_tpr, linestyle = '--', color = "purple", label = "Random Hangover")

# If desired
plt.legend()

plt.title("Random Forest ROC")

plt.xlabel("False Positive Rate")

plt.ylabel("True Positive Rate")

# to save the plot
# plt.savefig("whatever_name",dpi = 300)
Out[111]:
Text(0, 0.5, 'True Positive Rate')
In [112]:
# from sklearn.metrics import roc_auc_score

auc = roc_auc_score(valid_y, predProb_valid_rf[:,1])
auc
Out[112]:
0.6912407264120888
In [113]:
predProb_class_tr_new_rf = rf.predict_proba(class_tr_new)
predProb_class_tr_new_rf
Out[113]:
array([[0.21611204, 0.78388796]])
In [114]:
pd.DataFrame(predProb_class_tr_new_rf, columns = rf.classes_)
Out[114]:
0 1
0 0.216112 0.783888
In [115]:
rf_new_pred = rf.predict(class_tr_new)
rf_new_pred
Out[115]:
array([1], dtype=int64)
In [116]:
rf_new_pred_df = pd.DataFrame(rf_new_pred, columns = ["Prediction"])
rf_new_pred_df
Out[116]:
Prediction
0 1

No matter how we look at it, we're still hungover :-)

hangover_logo.jpeg