Regression Tree¶


1. Load Data¶

1.1 Libraries¶

In [1]:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor

1.2 Data¶

In [2]:
dnd_df = pd.read_csv("super_heroes_dnd_v3a.csv")
dnd_df.head()
Out[2]:
ID Name Gender Race Height Publisher Alignment Weight STR DEX CON INT WIS CHA Level HP
0 A001 A-Bomb Male Human 203.0 Marvel Comics good 441.0 18 11 17 12 13 11 1 7
1 A002 Abe Sapien Male Icthyo Sapien 191.0 Dark Horse Comics good 65.0 16 17 10 13 15 11 8 72
2 A004 Abomination Male Human / Radiation 203.0 Marvel Comics bad 441.0 13 14 13 10 18 15 15 135
3 A009 Agent 13 Female NaN 173.0 Marvel Comics good 61.0 15 18 16 16 17 10 14 140
4 A015 Alex Mercer Male Human NaN Wildstorm bad NaN 14 17 13 12 10 11 9 72
In [3]:
dnd_df.dtypes
Out[3]:
ID            object
Name          object
Gender        object
Race          object
Height       float64
Publisher     object
Alignment     object
Weight       float64
STR            int64
DEX            int64
CON            int64
INT            int64
WIS            int64
CHA            int64
Level          int64
HP             int64
dtype: object

It's a good idea to get a sense of the target variable.

In [4]:
dnd_df["HP"].describe()
Out[4]:
count    734.000000
mean      66.885559
std       36.653877
min        6.000000
25%       36.000000
50%       63.000000
75%       91.000000
max      150.000000
Name: HP, dtype: float64
In [5]:
pd.DataFrame(dnd_df.columns.values, columns = ["variables"])
Out[5]:
variables
0 ID
1 Name
2 Gender
3 Race
4 Height
5 Publisher
6 Alignment
7 Weight
8 STR
9 DEX
10 CON
11 INT
12 WIS
13 CHA
14 Level
15 HP
In [6]:
dnd_df_2 = dnd_df.iloc[:, np.r_[8:14, 15]]
dnd_df_2

# Alternatively, use:
# dnd_df.iloc[:, list(range(8, 14)) + [15]]
# Note: the stop of a positional range is excluded, so 8:14 selects columns 8-13

# Or just use:
# dnd_df.iloc[:, [8, 9, 10, 11, 12, 13, 15]]

# Or use a label range; note that loc slices include both endpoints,
# so "STR":"HP" would also pick up Level. To skip Level:
# dnd_df.loc[:, "STR":"CHA"].join(dnd_df["HP"])

# Or specify the variable names
# dnd_df.loc[:, ["STR", "DEX", "CON", "INT", "WIS", "CHA", "HP"]]
Out[6]:
STR DEX CON INT WIS CHA HP
0 18 11 17 12 13 11 7
1 16 17 10 13 15 11 72
2 13 14 13 10 18 15 135
3 15 18 16 16 17 10 140
4 14 17 13 12 10 11 72
... ... ... ... ... ... ... ...
729 8 14 17 13 14 15 64
730 17 12 11 11 14 10 56
731 18 10 14 17 10 10 49
732 11 11 10 12 15 16 36
733 16 12 18 15 15 16 81

734 rows × 7 columns
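Because iloc stops are exclusive while loc label slices include both endpoints, it's worth confirming the selection grabbed the intended columns; a quick check:

# Confirm the selection: positions 8:14 cover STR-CHA, plus 15 for HP.
print(dnd_df_2.columns.tolist())
# ['STR', 'DEX', 'CON', 'INT', 'WIS', 'CHA', 'HP']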

2. Training-Validation Split¶

In [7]:
import sklearn
from sklearn.model_selection import train_test_split
In [8]:
predictors = ["STR", "DEX", "CON", "INT", "WIS", "CHA"]
outcome = "HP"
In [9]:
X = dnd_df_2.drop(columns = ["HP"])
y = dnd_df_2["HP"]
In [10]:
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.4, random_state = 666)
In [11]:
train_X.head()
Out[11]:
STR DEX CON INT WIS CHA
650 17 14 16 16 15 17
479 8 18 16 10 14 17
271 9 12 17 10 15 17
647 9 18 16 10 17 13
307 12 16 14 18 15 13
In [12]:
len(train_X)
Out[12]:
440
In [13]:
train_y.head()
Out[13]:
650    117
479    120
271     72
647    117
307    100
Name: HP, dtype: int64
In [14]:
len(train_y)
Out[14]:
440
In [15]:
valid_X.head()
Out[15]:
STR DEX CON INT WIS CHA
389 10 16 15 13 11 10
131 18 10 12 10 16 18
657 10 11 12 11 18 14
421 16 13 11 16 13 11
160 12 16 17 18 11 15
In [16]:
len(valid_X)
Out[16]:
294
In [17]:
valid_y.head()
Out[17]:
389    45
131    42
657    63
421    64
160    54
Name: HP, dtype: int64
In [18]:
len(valid_y)
Out[18]:
294

3. Decision Tree¶

3.1 Large Tree¶

In [19]:
full_tree = DecisionTreeRegressor(random_state = 666)
full_tree
Out[19]:
DecisionTreeRegressor(random_state=666)
In [20]:
full_tree_fit = full_tree.fit(train_X, train_y)

Plot the tree

In [21]:
from sklearn import tree

Export only the top levels for illustration by setting max_depth; omit it to export the whole tree.

In [22]:
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- feature_3 <= 10.50
|   |   |   |--- feature_5 <= 16.00
|   |   |   |   |--- feature_2 <= 13.50
|   |   |   |   |   |--- value: [18.00]
|   |   |   |   |--- feature_2 >  13.50
|   |   |   |   |   |--- feature_0 <= 11.50
|   |   |   |   |   |   |--- value: [48.00]
|   |   |   |   |   |--- feature_0 >  11.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |--- feature_5 >  16.00
|   |   |   |   |--- value: [9.00]
|   |   |--- feature_3 >  10.50
|   |   |   |--- feature_2 <= 13.50
|   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |--- feature_0 <= 12.00
|   |   |   |   |   |   |--- value: [40.00]
|   |   |   |   |   |--- feature_0 >  12.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |--- feature_0 >  17.50
|   |   |   |   |   |--- value: [90.00]
|   |   |   |--- feature_2 >  13.50
|   |   |   |   |--- feature_0 <= 10.50
|   |   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |   |--- value: [56.00]
|   |   |   |   |   |--- feature_3 >  13.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |   |--- feature_0 >  10.50
|   |   |   |   |   |--- feature_2 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 >  17.50
|   |   |   |   |   |   |--- value: [120.00]
|   |--- feature_3 >  14.50
|   |   |--- feature_5 <= 10.50
|   |   |   |--- feature_0 <= 17.50
|   |   |   |   |--- feature_4 <= 16.00
|   |   |   |   |   |--- value: [112.00]
|   |   |   |   |--- feature_4 >  16.00
|   |   |   |   |   |--- value: [84.00]
|   |   |   |--- feature_0 >  17.50
|   |   |   |   |--- value: [49.00]
|   |   |--- feature_5 >  10.50
|   |   |   |--- feature_4 <= 11.50
|   |   |   |   |--- feature_3 <= 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |   |--- feature_4 >  10.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- feature_3 >  16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [54.00]
|   |   |   |   |   |--- feature_4 >  10.50
|   |   |   |   |   |   |--- value: [20.00]
|   |   |   |--- feature_4 >  11.50
|   |   |   |   |--- feature_4 <= 17.50
|   |   |   |   |   |--- feature_5 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_5 >  12.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_4 >  17.50
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 >  17.50
|   |   |   |   |   |   |--- value: [50.00]
|--- feature_1 >  10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- feature_5 <= 10.50
|   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |--- feature_4 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_4 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 >  12.50
|   |   |   |   |   |--- feature_2 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- feature_2 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_5 >  10.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_3 <= 10.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_3 >  10.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_5 >  17.50
|   |   |   |   |   |--- feature_2 <= 15.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 >  15.50
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 >  17.50
|   |   |   |--- feature_1 <= 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_0 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_0 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 >  12.50
|   |   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_0 >  17.50
|   |   |   |   |   |   |--- value: [8.00]
|   |   |   |--- feature_1 >  15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_3 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 >  12.50
|   |   |   |   |   |--- feature_1 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_1 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |--- feature_4 >  17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- feature_3 <= 16.50
|   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_1 >  14.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_3 >  13.50
|   |   |   |   |   |--- feature_1 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_1 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_3 >  16.50
|   |   |   |   |--- feature_2 <= 17.00
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- value: [72.00]
|   |   |   |   |   |--- feature_3 >  17.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_2 >  17.00
|   |   |   |   |   |--- value: [117.00]
|   |   |--- feature_0 >  14.50
|   |   |   |--- feature_0 <= 15.50
|   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |--- feature_5 <= 16.00
|   |   |   |   |   |   |--- value: [9.00]
|   |   |   |   |   |--- feature_5 >  16.00
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |--- feature_1 >  14.50
|   |   |   |   |   |--- value: [28.00]
|   |   |   |--- feature_0 >  15.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 >  12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_5 >  17.50
|   |   |   |   |   |--- value: [112.00]
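
The listing labels splits as feature_0 ... feature_5 because feature_names was not passed: feature_0 is STR, feature_1 is DEX, and so on, following the column order of train_X. A small sketch that prints named rules instead (top two levels only):

# Pass feature_names so the rules print as STR/DEX/... instead of feature_i.
named_rules = tree.export_text(full_tree, feature_names = list(train_X.columns), max_depth = 2)
print(named_rules)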

Plot only the top levels for illustration by setting max_depth (here 5); omit it to plot the whole tree.

In [23]:
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
Out[23]:
[Text(0.45454545454545453, 0.9285714285714286, 'DEX <= 10.5\nsquared_error = 1382.015\nsamples = 440\nvalue = 65.552'),
 Text(0.1690340909090909, 0.7857142857142857, 'INT <= 14.5\nsquared_error = 933.256\nsamples = 43\nvalue = 52.0'),
 Text(0.07670454545454546, 0.6428571428571429, 'INT <= 10.5\nsquared_error = 742.63\nsamples = 21\nvalue = 64.476'),
 Text(0.03409090909090909, 0.5, 'CHA <= 16.0\nsquared_error = 325.688\nsamples = 4\nvalue = 31.25'),
 Text(0.022727272727272728, 0.35714285714285715, 'CON <= 13.5\nsquared_error = 214.222\nsamples = 3\nvalue = 38.667'),
 Text(0.011363636363636364, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 18.0'),
 Text(0.03409090909090909, 0.21428571428571427, 'STR <= 11.5\nsquared_error = 1.0\nsamples = 2\nvalue = 49.0'),
 Text(0.022727272727272728, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.045454545454545456, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.045454545454545456, 0.35714285714285715, 'squared_error = 0.0\nsamples = 1\nvalue = 9.0'),
 Text(0.11931818181818182, 0.5, 'CON <= 13.5\nsquared_error = 519.855\nsamples = 17\nvalue = 72.294'),
 Text(0.09090909090909091, 0.35714285714285715, 'STR <= 17.5\nsquared_error = 187.484\nsamples = 8\nvalue = 58.375'),
 Text(0.07954545454545454, 0.21428571428571427, 'STR <= 12.0\nsquared_error = 50.98\nsamples = 7\nvalue = 53.857'),
 Text(0.06818181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.09090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.10227272727272728, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 90.0'),
 Text(0.14772727272727273, 0.35714285714285715, 'STR <= 10.5\nsquared_error = 490.0\nsamples = 9\nvalue = 84.667'),
 Text(0.125, 0.21428571428571427, 'INT <= 13.5\nsquared_error = 9.0\nsamples = 2\nvalue = 53.0'),
 Text(0.11363636363636363, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.13636363636363635, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.17045454545454544, 0.21428571428571427, 'CON <= 17.5\nsquared_error = 259.061\nsamples = 7\nvalue = 93.714'),
 Text(0.1590909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.18181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.26136363636363635, 0.6428571428571429, 'CHA <= 10.5\nsquared_error = 824.81\nsamples = 22\nvalue = 40.091'),
 Text(0.2159090909090909, 0.5, 'STR <= 17.5\nsquared_error = 664.222\nsamples = 3\nvalue = 81.667'),
 Text(0.20454545454545456, 0.35714285714285715, 'WIS <= 16.0\nsquared_error = 196.0\nsamples = 2\nvalue = 98.0'),
 Text(0.19318181818181818, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 112.0'),
 Text(0.2159090909090909, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 84.0'),
 Text(0.22727272727272727, 0.35714285714285715, 'squared_error = 0.0\nsamples = 1\nvalue = 49.0'),
 Text(0.3068181818181818, 0.5, 'WIS <= 11.5\nsquared_error = 534.144\nsamples = 19\nvalue = 33.526'),
 Text(0.26136363636363635, 0.35714285714285715, 'INT <= 16.5\nsquared_error = 319.04\nsamples = 5\nvalue = 19.6'),
 Text(0.23863636363636365, 0.21428571428571427, 'WIS <= 10.5\nsquared_error = 2.667\nsamples = 3\nvalue = 8.0'),
 Text(0.22727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.25, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.2840909090909091, 0.21428571428571427, 'WIS <= 10.5\nsquared_error = 289.0\nsamples = 2\nvalue = 37.0'),
 Text(0.2727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.29545454545454547, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3522727272727273, 0.35714285714285715, 'WIS <= 17.5\nsquared_error = 516.964\nsamples = 14\nvalue = 38.5'),
 Text(0.32954545454545453, 0.21428571428571427, 'CHA <= 12.5\nsquared_error = 382.41\nsamples = 10\nvalue = 46.3'),
 Text(0.3181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3409090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.375, 0.21428571428571427, 'INT <= 17.5\nsquared_error = 321.0\nsamples = 4\nvalue = 19.0'),
 Text(0.36363636363636365, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.38636363636363635, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7400568181818182, 0.7857142857142857, 'WIS <= 17.5\nsquared_error = 1408.574\nsamples = 397\nvalue = 67.02'),
 Text(0.5795454545454546, 0.6428571428571429, 'CON <= 17.5\nsquared_error = 1407.787\nsamples = 359\nvalue = 68.111'),
 Text(0.48863636363636365, 0.5, 'CHA <= 10.5\nsquared_error = 1409.398\nsamples = 323\nvalue = 69.372'),
 Text(0.4431818181818182, 0.35714285714285715, 'CON <= 12.5\nsquared_error = 1278.057\nsamples = 43\nvalue = 78.419'),
 Text(0.42045454545454547, 0.21428571428571427, 'WIS <= 11.5\nsquared_error = 1262.102\nsamples = 14\nvalue = 57.429'),
 Text(0.4090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4318181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4659090909090909, 0.21428571428571427, 'CON <= 16.5\nsquared_error = 970.385\nsamples = 29\nvalue = 88.552'),
 Text(0.45454545454545453, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4772727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5340909090909091, 0.35714285714285715, 'CHA <= 17.5\nsquared_error = 1415.068\nsamples = 280\nvalue = 67.982'),
 Text(0.5113636363636364, 0.21428571428571427, 'INT <= 10.5\nsquared_error = 1402.723\nsamples = 241\nvalue = 66.593'),
 Text(0.5, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5227272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5568181818181818, 0.21428571428571427, 'CON <= 15.5\nsquared_error = 1405.784\nsamples = 39\nvalue = 76.564'),
 Text(0.5454545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5681818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6704545454545454, 0.5, 'DEX <= 15.5\nsquared_error = 1251.268\nsamples = 36\nvalue = 56.806'),
 Text(0.625, 0.35714285714285715, 'WIS <= 12.5\nsquared_error = 779.741\nsamples = 21\nvalue = 45.857'),
 Text(0.6022727272727273, 0.21428571428571427, 'STR <= 16.5\nsquared_error = 788.29\nsamples = 10\nvalue = 54.9'),
 Text(0.5909090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6136363636363636, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6477272727272727, 0.21428571428571427, 'STR <= 17.5\nsquared_error = 630.05\nsamples = 11\nvalue = 37.636'),
 Text(0.6363636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6590909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7159090909090909, 0.35714285714285715, 'WIS <= 12.5\nsquared_error = 1508.649\nsamples = 15\nvalue = 72.133'),
 Text(0.6931818181818182, 0.21428571428571427, 'INT <= 11.5\nsquared_error = 584.889\nsamples = 6\nvalue = 37.333'),
 Text(0.6818181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7045454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7386363636363636, 0.21428571428571427, 'DEX <= 16.5\nsquared_error = 778.889\nsamples = 9\nvalue = 95.333'),
 Text(0.7272727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.75, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9005681818181818, 0.6428571428571429, 'STR <= 14.5\nsquared_error = 1298.469\nsamples = 38\nvalue = 56.711'),
 Text(0.8465909090909091, 0.5, 'INT <= 16.5\nsquared_error = 1403.386\nsamples = 26\nvalue = 61.808'),
 Text(0.8068181818181818, 0.35714285714285715, 'INT <= 13.5\nsquared_error = 1391.959\nsamples = 21\nvalue = 54.429'),
 Text(0.7840909090909091, 0.21428571428571427, 'DEX <= 14.5\nsquared_error = 1660.628\nsamples = 11\nvalue = 67.909'),
 Text(0.7727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7954545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8295454545454546, 0.21428571428571427, 'DEX <= 11.5\nsquared_error = 676.64\nsamples = 10\nvalue = 39.6'),
 Text(0.8181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8409090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8863636363636364, 0.35714285714285715, 'CON <= 17.0\nsquared_error = 262.16\nsamples = 5\nvalue = 92.8'),
 Text(0.875, 0.21428571428571427, 'INT <= 17.5\nsquared_error = 144.688\nsamples = 4\nvalue = 86.75'),
 Text(0.8636363636363636, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8863636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8977272727272727, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 117.0'),
 Text(0.9545454545454546, 0.5, 'STR <= 15.5\nsquared_error = 892.889\nsamples = 12\nvalue = 45.667'),
 Text(0.9318181818181818, 0.35714285714285715, 'DEX <= 14.5\nsquared_error = 76.5\nsamples = 4\nvalue = 13.0'),
 Text(0.9204545454545454, 0.21428571428571427, 'CHA <= 16.0\nsquared_error = 2.0\nsamples = 3\nvalue = 8.0'),
 Text(0.9090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9318181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9431818181818182, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 28.0'),
 Text(0.9772727272727273, 0.35714285714285715, 'CHA <= 17.5\nsquared_error = 500.75\nsamples = 8\nvalue = 62.0'),
 Text(0.9659090909090909, 0.21428571428571427, 'CON <= 12.5\nsquared_error = 164.122\nsamples = 7\nvalue = 54.857'),
 Text(0.9545454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9772727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9886363636363636, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 112.0')]
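plot_tree returns the list of matplotlib Text objects shown above. In a notebook you would normally enlarge the figure and suppress that return value; a minimal sketch (the figure size and depth are arbitrary choices):

import matplotlib.pyplot as plt

# Draw the top levels on a larger canvas; assigning/ignoring the return
# value keeps the Text list out of the cell output.
fig, ax = plt.subplots(figsize = (20, 10))
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 3,
               filled = True, fontsize = 8, ax = ax)
plt.show()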

Export the tree to a .dot file, then convert it to a picture file.

In [24]:
from sklearn.tree import export_graphviz
In [25]:
dot_data = export_graphviz(full_tree, out_file='full_tree.dot', feature_names = train_X.columns)
# Note: when out_file is given, export_graphviz writes the file and returns None,
# so dot_data is None here.
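
One way to do the conversion in Python, as a sketch; this assumes the graphviz package and the Graphviz binaries are installed (pip install graphviz):

import graphviz

# Read the exported .dot file and render it as full_tree.png.
with open("full_tree.dot") as f:
    graph = graphviz.Source(f.read())
graph.render("full_tree", format = "png", cleanup = True)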

Not very useful.

full_tree.png

3.2 Small Tree¶

In [26]:
small_tree = DecisionTreeRegressor(random_state = 666, max_depth = 3, min_samples_split = 25)
small_tree
Out[26]:
DecisionTreeRegressor(max_depth=3, min_samples_split=25, random_state=666)
In [27]:
small_tree_fit = small_tree.fit(train_X, train_y)

Plot the tree

In [28]:
# Already imported above:
# from sklearn import tree

Export the whole tree; max_depth is omitted here because the tree is already small.

In [29]:
text_representation_2 = tree.export_text(small_tree)
print(text_representation_2)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- value: [64.48]
|   |--- feature_3 >  14.50
|   |   |--- value: [40.09]
|--- feature_1 >  10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- value: [69.37]
|   |   |--- feature_2 >  17.50
|   |   |   |--- value: [56.81]
|   |--- feature_4 >  17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- value: [61.81]
|   |   |--- feature_0 >  14.50
|   |   |   |--- value: [45.67]

Plot the whole tree; max_depth is omitted because the tree is small.

In [30]:
tree.plot_tree(small_tree, feature_names = train_X.columns)
Out[30]:
[Text(0.4090909090909091, 0.875, 'DEX <= 10.5\nsquared_error = 1382.015\nsamples = 440\nvalue = 65.552'),
 Text(0.18181818181818182, 0.625, 'INT <= 14.5\nsquared_error = 933.256\nsamples = 43\nvalue = 52.0'),
 Text(0.09090909090909091, 0.375, 'squared_error = 742.63\nsamples = 21\nvalue = 64.476'),
 Text(0.2727272727272727, 0.375, 'squared_error = 824.81\nsamples = 22\nvalue = 40.091'),
 Text(0.6363636363636364, 0.625, 'WIS <= 17.5\nsquared_error = 1408.574\nsamples = 397\nvalue = 67.02'),
 Text(0.45454545454545453, 0.375, 'CON <= 17.5\nsquared_error = 1407.787\nsamples = 359\nvalue = 68.111'),
 Text(0.36363636363636365, 0.125, 'squared_error = 1409.398\nsamples = 323\nvalue = 69.372'),
 Text(0.5454545454545454, 0.125, 'squared_error = 1251.268\nsamples = 36\nvalue = 56.806'),
 Text(0.8181818181818182, 0.375, 'STR <= 14.5\nsquared_error = 1298.469\nsamples = 38\nvalue = 56.711'),
 Text(0.7272727272727273, 0.125, 'squared_error = 1403.386\nsamples = 26\nvalue = 61.808'),
 Text(0.9090909090909091, 0.125, 'squared_error = 892.889\nsamples = 12\nvalue = 45.667')]

Export the tree to a .dot file, then convert it to a picture file.

In [31]:
# Already imported above:
# from sklearn.tree import export_graphviz
In [32]:
dot_data_2 = export_graphviz(small_tree, out_file='small_tree.dot', feature_names = train_X.columns)

Much better.

small_tree.png

3.3 Predictions¶

3.3.1 Predictions Using the Full Tree¶

In [33]:
train_y_pred_full = full_tree.predict(train_X)
train_y_pred_full
Out[33]:
array([117., 120.,  72., 117., 100.,  90.,  49.,  60.,  42.,  50.,  20.,
        32.,  36.,  80.,  36.,  30., 117.,  99.,  90.,  12.,  42.,  28.,
        40.,  24.,  50.,  98.,  32.,  90.,  50.,  88., 130.,  35.,  18.,
        70.,  99.,  66.,  20., 150.,  56.,   8.,  98.,  54.,  81.,  35.,
         6.,  81.,  56.,  90.,  88.,  70.,  60., 104.,  24.,  81.,  54.,
        60.,  30.,  80.,  84.,  98.,  12., 140., 135.,  56., 135.,  30.,
       117.,  99.,  81., 105.,  42.,  48., 100., 110.,  77.,  84.,  84.,
       104.,  16.,  64.,  48.,  16.,  84.,  18.,  48.,  20.,  24.,  54.,
         9.,  99.,  56., 140.,  72.,  20., 112.,   8., 110., 120.,  35.,
        63.,  21.,  99.,  36.,  72.,  16.,  77., 150.,  50.,  90.,  78.,
        60.,  81., 104.,  45.,  56.,   7.,  10.,  60.,  56.,  96.,  72.,
        28.,  40.,  72.,  78.,  18.,  54., 110.,   8.,  16.,  84., 130.,
        88.,  90.,  54., 100., 110.,  72.,  90.,  81.,   8.,  72.,  30.,
       140., 126., 105.,  36.,  18., 140.,  30.,  32.,  18.,  66.,  63.,
        24.,  78.,  21.,  16.,  32.,   9.,  28., 130.,  42.,  70., 105.,
        56., 135.,  63.,  45.,  72.,  72.,   6., 104.,  64.,  96.,  90.,
        20.,  84.,   7.,  90.,  63.,  42.,  60.,  72.,  49.,   9.,   6.,
        90., 130.,  90.,  90.,  42.,  35.,   9.,  45.,  40., 108.,  21.,
       108.,  30.,  84., 112., 135., 112., 112.,  18.,  84.,  50.,  40.,
         9.,  99.,  81.,  72.,  72.,  30.,  60.,  96.,  27., 140.,  60.,
        90.,  72.,  42.,  72.,  81.,  80., 117.,  32., 135.,   8.,  36.,
        63.,  80.,  16., 120.,  72., 100., 110.,  48.,  42.,  64., 130.,
        48.,  90.,  84.,  54.,  48.,  54.,  18.,  80.,  49.,  84., 150.,
        78., 126.,  63.,   9.,  16.,  50., 120.,   8.,  32.,  56., 135.,
        16.,  77.,  24.,  60.,  48.,  18.,   8.,  70.,  63.,  54.,  91.,
        80., 112.,  70., 120., 120., 120.,   8.,  56.,  12.,  88.,  28.,
        18.,  81.,  48.,  91., 117.,  42.,  49., 140.,  28., 120.,  56.,
       110., 130.,  72.,  18.,  77., 126.,  32.,  42.,  36.,  16.,   9.,
        88.,  54.,  72.,  30., 126.,  88.,  84.,  24.,  60., 117., 104.,
       120.,  77., 105.,  42., 110.,  88.,  56.,  35.,  42.,  80.,  30.,
        50.,  48.,  24.,  21.,  56.,  72.,   9.,  63.,  98.,  60.,  48.,
        16., 117.,  30.,  70., 104.,  49.,  21., 130.,  56., 117.,  78.,
         8.,  36.,  48.,  91.,  84.,  24.,  36.,  72.,  10.,  18.,  36.,
        80.,  90., 112.,  63.,  32.,  96.,  72., 108.,  80.,   9.,  18.,
        98.,  18.,  88.,  20.,  18.,  30.,  12.,  54.,  36.,  42., 120.,
        70.,  32.,   7.,  40.,  63.,  77.,  28.,  24.,   7.,  36.,  48.,
        54.,  10.,  56.,  42., 135.,  98.,  10.,  54.,  84.,  54.,  60.,
       117., 135.,  35., 117.,  72., 130.,  63., 110.,  21.,  81.,  48.,
       110.,  54.,  60.,  49.,  91.,  72.,  48.,  10.,  77.,  72., 112.,
        45., 150.,  88., 150., 135., 140.,  32.,  70.,  80.,  72.,  91.])

Get the RMSE for the training set

In [34]:
mse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full)
mse_full_tree_train
Out[34]:
0.0
In [35]:
import math
In [36]:
rmse_full_tree_train = math.sqrt(mse_full_tree_train)
rmse_full_tree_train
Out[36]:
0.0
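A training RMSE of 0 is expected here: the unpruned tree keeps splitting until it reproduces the training data exactly, which is exactly why the validation error below is the number to watch. As an aside, newer scikit-learn versions can skip the manual square root; a version-dependent sketch:

# Either line replaces the mean_squared_error + math.sqrt steps:
# rmse = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full, squared = False)
# In scikit-learn >= 1.4:
# rmse = sklearn.metrics.root_mean_squared_error(train_y, train_y_pred_full)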

If using the dmba package, install it first:

pip install dmba or conda install -c conda-forge dmba

import dmba
from dmba import regressionSummary

In [37]:
import dmba
from dmba import regressionSummary
In [38]:
regressionSummary(train_y, train_y_pred_full)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 0.0000
            Mean Absolute Error (MAE) : 0.0000
          Mean Percentage Error (MPE) : 0.0000
Mean Absolute Percentage Error (MAPE) : 0.0000

On the validation set

In [39]:
valid_y_pred_full = full_tree.predict(valid_X)
valid_y_pred_full
Out[39]:
array([110.,   9.,  70., 104.,  84.,  36., 150.,  35.,  90.,  20.,  32.,
       130., 117.,  96., 120.,  21.,  70.,  35.,  18.,  54.,  64.,  27.,
        32.,  72., 126.,  88.,  30.,  20.,  40.,   9., 126., 112.,  21.,
        54., 110.,   9.,  21.,  70., 140., 110.,  72.,  36., 117., 105.,
        72.,  18.,   8.,  54.,  81.,  40.,  36., 135.,  90., 112.,  72.,
        91., 110.,  63., 100.,  27., 110.,   7.,  30.,  90.,  40.,  16.,
        32., 104.,  48.,  90.,  12., 140.,  36., 135.,  16.,  18.,  60.,
        96.,  84.,  54.,  30.,  42.,  80.,  28.,  30.,  16., 130.,   8.,
        90.,  24.,  40., 117.,  49., 140.,  20.,  42.,  72.,  90., 108.,
        42.,  60.,  80., 135.,  48.,  42.,  84.,  91.,  72.,  84.,  36.,
        50.,  50.,  56., 126.,  28.,  42.,  80.,  36.,  96.,  36.,   8.,
        72.,  10.,  72.,  77.,  10., 135.,  54., 140.,  16., 140.,  40.,
         9.,  63.,   8.,  54., 150.,  63.,  56.,  77.,  84., 140., 135.,
         9.,  84.,  36.,   8.,  91.,  16.,  42.,   8., 110.,  56.,  36.,
        20.,  18.,  63., 117.,  81.,  98.,  81.,  84., 120.,  36.,   8.,
        27., 130.,  12.,  63.,   9.,  30.,  84., 130.,  16.,  56.,  90.,
        21.,  63.,  32.,  40.,  40., 140.,  60.,  84.,  18., 100.,  18.,
       126.,  36., 150.,  80.,  77.,  16.,  56.,  88., 110.,  18.,   8.,
        18.,  96.,  42.,  30.,  21., 110.,  54.,  60.,  42., 135., 126.,
       130.,  72., 117.,  72.,  21.,  90.,  70.,  48., 117.,  24.,  60.,
        45.,  91.,   8.,   6., 130.,  81.,  50.,  66.,  96., 130.,  63.,
        63.,   8.,  78.,  70.,  63.,   9.,  72.,  18.,  80.,  54.,  72.,
        56.,  40.,  84.,  60.,  70.,  49.,  64.,  30.,  54.,  28.,  72.,
       104.,  78.,  18.,  56.,  54.,   6.,  77.,  54., 120.,  56.,  54.,
        45.,  56., 126.,  63.,  72.,  80.,  50.,  63., 117.,  63., 112.,
        90., 117.,  36.,  42.,  90., 110., 135.,  56.,  60.,  20.,  56.,
        81.,  84., 150.,   9., 120.,  42.,  72.,  21.])

Get the RMSE for the validation set

In [40]:
mse_full_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_full)
mse_full_tree_valid
Out[40]:
2643.622448979592
In [41]:
import math
In [42]:
rmse_full_tree_valid = math.sqrt(mse_full_tree_valid)
rmse_full_tree_valid
Out[42]:
51.41616913947977

If using the dmba package, install it first:

pip install dmba or conda install -c conda-forge dmba

import dmba
from dmba import regressionSummary

In [43]:
regressionSummary(valid_y, valid_y_pred_full)
Regression statistics

                      Mean Error (ME) : 3.7041
       Root Mean Squared Error (RMSE) : 51.4162
            Mean Absolute Error (MAE) : 41.2415
          Mean Percentage Error (MPE) : -47.2748
Mean Absolute Percentage Error (MAPE) : 99.1131

3.3.2 Predictions Using the Small Tree¶

On the training set

In [44]:
train_y_pred = small_tree.predict(train_X)
train_y_pred
Out[44]:
array([69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       61.80769231, 40.09090909, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 56.80555556, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556,
       56.80555556, 61.80769231, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 56.80555556, 64.47619048, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703,
       61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 56.80555556,
       56.80555556, 69.37151703, 40.09090909, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       61.80769231, 69.37151703, 56.80555556, 64.47619048, 56.80555556,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 40.09090909, 69.37151703, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 69.37151703, 69.37151703, 61.80769231, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 69.37151703, 69.37151703, 40.09090909,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 69.37151703, 61.80769231, 40.09090909,
       69.37151703, 64.47619048, 64.47619048, 56.80555556, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 56.80555556, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       40.09090909, 64.47619048, 69.37151703, 61.80769231, 61.80769231,
       45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       40.09090909, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703,
       61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 45.66666667, 45.66666667, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 69.37151703, 40.09090909, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703,
       61.80769231, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 56.80555556, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703,
       61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 69.37151703, 69.37151703, 56.80555556,
       61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667,
       45.66666667, 69.37151703, 61.80769231, 56.80555556, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703])

Get the RMSE for the training set

In [45]:
mse_small_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred)
mse_small_tree_train
Out[45]:
1320.9654960245605
In [46]:
import math
In [47]:
rmse_small_tree_train = math.sqrt(mse_small_tree_train)
rmse_small_tree_train
Out[47]:
36.34508902210256

If using the dmba package, install it first:

pip install dmba or conda install -c conda-forge dmba

import dmba
from dmba import regressionSummary

In [48]:
import dmba
from dmba import regressionSummary
In [49]:
regressionSummary(train_y, train_y_pred)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 36.3451
            Mean Absolute Error (MAE) : 30.4192
          Mean Percentage Error (MPE) : -75.0378
Mean Absolute Percentage Error (MAPE) : 103.4166

On the validation set

In [50]:
valid_y_pred = small_tree.predict(valid_X)
valid_y_pred
Out[50]:
array([69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       61.80769231, 45.66666667, 61.80769231, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 56.80555556, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       64.47619048, 61.80769231, 56.80555556, 69.37151703, 56.80555556,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556,
       61.80769231, 56.80555556, 69.37151703, 40.09090909, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 64.47619048, 64.47619048,
       56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703,
       45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       40.09090909, 69.37151703, 61.80769231, 69.37151703, 45.66666667,
       69.37151703, 64.47619048, 69.37151703, 69.37151703, 45.66666667,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       64.47619048, 69.37151703, 45.66666667, 56.80555556, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 40.09090909, 61.80769231, 69.37151703,
       69.37151703, 61.80769231, 56.80555556, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       45.66666667, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703,
       61.80769231, 45.66666667, 69.37151703, 69.37151703, 56.80555556,
       40.09090909, 64.47619048, 69.37151703, 40.09090909, 69.37151703,
       45.66666667, 69.37151703, 64.47619048, 64.47619048, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 40.09090909, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 40.09090909, 69.37151703, 45.66666667,
       64.47619048, 69.37151703, 56.80555556, 69.37151703])

Get the RMSE for the validation set

In [51]:
mse_small_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred)
mse_small_tree_valid
Out[51]:
1353.8427521618253
In [52]:
import math
In [53]:
rmse_small_tree_valid = math.sqrt(mse_small_tree_valid)
rmse_small_tree_valid
Out[53]:
36.79460221502368

If using the dmba package, install it first:

pip install dmba or conda install -c conda-forge dmba

import dmba
from dmba import regressionSummary

In [54]:
regressionSummary(valid_y, valid_y_pred)
Regression statistics

                      Mean Error (ME) : 4.2030
       Root Mean Squared Error (RMSE) : 36.7946
            Mean Absolute Error (MAE) : 30.2375
          Mean Percentage Error (MPE) : -48.2303
Mean Absolute Percentage Error (MAPE) : 80.0992
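
Putting the two models side by side makes the overfitting story concrete: the full tree is perfect on the training set but far worse on validation, while the small tree performs almost identically on both. A quick comparison using the values computed above:

print(f"full tree : {rmse_full_tree_valid:.2f}")   # ~51.42
print(f"small tree: {rmse_small_tree_valid:.2f}")  # ~36.79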

4. Exhaustive Search¶

4.1 Best Tree¶

In [55]:
from sklearn.model_selection import GridSearchCV
In [56]:
param_grid = {"max_depth": [2, 3, 4, 5],
              "min_samples_split": [10, 20, 30],
              "min_impurity_decrease": [0, 0.001, 0.002]}
In [57]:
grid_search = GridSearchCV(DecisionTreeRegressor(random_state = 666), param_grid, cv = 10)
In [58]:
grid_search.fit(train_X, train_y)
Out[58]:
GridSearchCV(cv=10, estimator=DecisionTreeRegressor(random_state=666),
             param_grid={'max_depth': [2, 3, 4, 5],
                         'min_impurity_decrease': [0, 0.001, 0.002],
                         'min_samples_split': [10, 20, 30]})
In [59]:
print("Initial parameters:", grid_search.best_params_)
Initial parameters: {'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
In [60]:
grid_search.best_score_
Out[60]:
-0.0933578084568057
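By default GridSearchCV scores a regressor with R-squared, so best_score_ is the mean cross-validated R-squared; a negative value means the candidate trees predicted the held-out folds worse than a constant mean would. To search on RMSE instead, a sketch using a standard scikit-learn scorer name:

# Rerun the same grid search scored by (negated) RMSE instead of R-squared.
grid_search_rmse = GridSearchCV(DecisionTreeRegressor(random_state = 666), param_grid,
                                cv = 10, scoring = "neg_root_mean_squared_error")
grid_search_rmse.fit(train_X, train_y)
print(grid_search_rmse.best_params_)
print(-grid_search_rmse.best_score_)   # mean cross-validated RMSE of the best tree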
In [61]:
grid_search.best_params_
Out[61]:
{'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
In [62]:
best_tree = grid_search.best_estimator_
best_tree
Out[62]:
DecisionTreeRegressor(max_depth=2, min_impurity_decrease=0,
                      min_samples_split=10, random_state=666)
In [63]:
dot_data_3 = export_graphviz(best_tree, out_file='best_tree.dot', feature_names = train_X.columns)

Convert the .dot file to an image (e.g., with Graphviz or an online converter).

best_tree.png

4.2 Predictions Using the Best Tree¶

On the training set

In [64]:
train_y_best_pred = best_tree.predict(train_X)
train_y_best_pred
Out[64]:
array([68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 40.09090909, 64.47619048, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 40.09090909,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 56.71052632,
       68.11142061, 68.11142061, 64.47619048, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 68.11142061, 68.11142061, 56.71052632, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 40.09090909,
       64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 64.47619048, 68.11142061, 56.71052632, 40.09090909,
       68.11142061, 64.47619048, 64.47619048, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909,
       40.09090909, 64.47619048, 68.11142061, 56.71052632, 56.71052632,
       56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 64.47619048, 56.71052632, 56.71052632, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 68.11142061, 40.09090909, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061,
       56.71052632, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 64.47619048, 68.11142061, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 40.09090909, 64.47619048, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632,
       56.71052632, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061])

Get the RMSE for the training set

In [65]:
mse_best_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_best_pred)
mse_best_tree_train
Out[65]:
1337.4509437318577
In [66]:
import math
In [67]:
rmse_best_tree_train = math.sqrt(mse_best_tree_train)
rmse_best_tree_train
Out[67]:
36.57117640617892
In [68]:
# If using the dmba package, install it first:

# pip install dmba
# or
# conda install -c conda-forge dmba

# import dmba
# from dmba import regressionSummary

regressionSummary(train_y, train_y_best_pred)
Regression statistics

                      Mean Error (ME) : -0.0000
       Root Mean Squared Error (RMSE) : 36.5712
            Mean Absolute Error (MAE) : 30.6315
          Mean Percentage Error (MPE) : -76.2539
Mean Absolute Percentage Error (MAPE) : 104.6671

On the validation set

In [69]:
valid_y_best_pred = best_tree.predict(valid_X)
valid_y_best_pred
Out[69]:
array([68.11142061, 64.47619048, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       56.71052632, 56.71052632, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909,
       68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       64.47619048, 68.11142061, 68.11142061, 68.11142061, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 40.09090909, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 64.47619048, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 68.11142061, 64.47619048, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 68.11142061, 56.71052632, 68.11142061, 56.71052632,
       68.11142061, 64.47619048, 68.11142061, 68.11142061, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048,
       64.47619048, 68.11142061, 56.71052632, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 40.09090909, 56.71052632, 68.11142061,
       68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061,
       56.71052632, 40.09090909, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061,
       56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061,
       68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061,
       56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061,
       40.09090909, 64.47619048, 68.11142061, 40.09090909, 68.11142061,
       56.71052632, 68.11142061, 64.47619048, 64.47619048, 56.71052632,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 68.11142061, 40.09090909, 40.09090909,
       68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061,
       68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061,
       68.11142061, 56.71052632, 40.09090909, 68.11142061, 56.71052632,
       64.47619048, 68.11142061, 68.11142061, 68.11142061])

Get the RMSE for the validation set

In [70]:
mse_best_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_best_pred)
mse_best_tree_valid
Out[70]:
1320.469997098233
In [71]:
import math
In [72]:
rmse_best_tree_valid = math.sqrt(mse_best_tree_valid)
rmse_best_tree_valid
Out[72]:
36.33827179570092
In [73]:
# If using the dmba package, install it first:

# pip install dmba
# or
# conda install -c conda-forge dmba

# import dmba
# from dmba import regressionSummary

regressionSummary(valid_y, valid_y_best_pred)
Regression statistics

                      Mean Error (ME) : 3.8074
       Root Mean Squared Error (RMSE) : 36.3383
            Mean Absolute Error (MAE) : 29.9632
          Mean Percentage Error (MPE) : -49.3930
Mean Absolute Percentage Error (MAPE) : 80.5791

5. New Records¶

Read in the new records to score. The columns must match the training predictors (same names and order).

In [74]:
new_dnd_df = pd.read_csv("new_records_dnd.csv")
new_dnd_df
Out[74]:
STR DEX CON INT WIS CHA
0 9 17 8 13 16 15
1 17 9 17 18 11 7

5.1 Using the Small Tree¶

In [75]:
new_records_dnd_small_pred = small_tree.predict(new_dnd_df)
new_records_dnd_small_pred
Out[75]:
array([69.37151703, 40.09090909])
In [76]:
import pandas as pd
dnd_small_tree_prediction_df = pd.DataFrame(new_records_dnd_small_pred,
                                         columns = ["Prediction"])
dnd_small_tree_prediction_df
Out[76]:
Prediction
0 69.371517
1 40.090909

Merge with new data

In [77]:
new_dnd_df_with_prediction_small_tree = pd.concat((new_dnd_df, dnd_small_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction_small_tree

# to export
# new_dnd_df_with_prediction_small_tree.to_csv("whatever_name.csv")
Out[77]:
STR DEX CON INT WIS CHA Prediction
0 9 17 8 13 16 15 69.371517
1 17 9 17 18 11 7 40.090909

5.2 Using the Best Tree¶

In [78]:
new_records_dnd_best_pred = best_tree.predict(new_dnd_df)
new_records_dnd_best_pred
Out[78]:
array([68.11142061, 40.09090909])
In [79]:
dnd_best_tree_prediction_df = pd.DataFrame(new_records_dnd_best_pred,
                                         columns = ["Prediction"])
dnd_best_tree_prediction_df

# to export
# dnd_best_tree_prediction_df.to_csv("whatever_name.csv")
Out[79]:
Prediction
0 68.111421
1 40.090909

Merge with new data

In [80]:
new_dnd_df_with_prediction_best_tree = pd.concat((new_dnd_df, dnd_best_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction_best_tree
Out[80]:
STR DEX CON INT WIS CHA Prediction
0 9 17 8 13 16 15 68.111421
1 17 9 17 18 11 7 40.090909
In [81]:
# .apply() returns, for each record, the index of the leaf node it lands in
leaf_number_for_new = best_tree.apply(new_dnd_df)
leaf_number_for_new
Out[81]:
array([5, 3], dtype=int64)
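
As a sanity check, the leaf ids can be traced back to the tree's stored leaf means: for a single-output DecisionTreeRegressor, tree_.value[node] holds the mean HP of the training records in that node, which is exactly what predict returns. A sketch:

# Mean HP stored in each node; value has shape (n_nodes, 1, 1) for this regressor
leaf_means = best_tree.tree_.value.reshape(-1)
leaf_means[leaf_number_for_new]   # should match Out[78]: ~[68.11, 40.09]
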
In [82]:
leaf_number_for_new_df = pd.DataFrame(leaf_number_for_new, columns = ["leaf_number"])
leaf_number_for_new_df
Out[82]:
leaf_number
0 5
1 3
In [83]:
new_dnd_df_with_prediction_best_tree_leaf_number = pd.concat((new_dnd_df_with_prediction_best_tree, 
                                                             leaf_number_for_new_df), axis = 1)
new_dnd_df_with_prediction_best_tree_leaf_number
Out[83]:
STR DEX CON INT WIS CHA Prediction leaf_number
0 9 17 8 13 16 15 68.111421 5
1 17 9 17 18 11 7 40.090909 3

5.2.1 Range of predictions using the best tree¶

In [84]:
# Leaf assignment for every training record, indexed to line up with train_y
leaf_number = pd.DataFrame(best_tree.apply(train_X), columns = ["leaf_number"], index = train_y.index)
In [85]:
leaf_number
Out[85]:
leaf_number
650 5
479 5
271 5
647 5
307 5
... ...
445 5
414 5
70 5
429 5
236 5

440 rows × 1 columns

In [86]:
leaf_df = pd.concat([leaf_number, train_y], axis = 1)
leaf_df
Out[86]:
leaf_number HP
650 5 117
479 5 120
271 5 72
647 5 117
307 5 100
... ... ...
445 5 32
414 5 70
70 5 80
429 5 72
236 5 91

440 rows × 2 columns

In [87]:
leaf_max_df = leaf_df.groupby(by = "leaf_number").max()
leaf_max_df
Out[87]:
HP
leaf_number
2 120
3 112
5 150
6 140
In [88]:
leaf_max_df = leaf_max_df.rename(columns = {"HP": "Max_HP"})
leaf_max_df
Out[88]:
Max_HP
leaf_number
2 120
3 112
5 150
6 140
In [89]:
leaf_min_df = leaf_df.groupby(by = "leaf_number").min()
leaf_min_df
Out[89]:
HP
leaf_number
2 9
3 6
5 7
6 6
In [90]:
leaf_min_df = leaf_min_df.rename(columns = {"HP": "Min_HP"})
leaf_min_df
Out[90]:
Min_HP
leaf_number
2 9
3 6
5 7
6 6
In [91]:
leaf_std_df = leaf_df.groupby(by = "leaf_number").std()
leaf_std_df
Out[91]:
HP
leaf_number
2 27.924217
3 29.395350
5 37.572854
6 36.517976
In [92]:
leaf_std_df = leaf_std_df.rename(columns = {"HP": "std_HP"})
leaf_std_df
Out[92]:
std_HP
leaf_number
2 27.924217
3 29.395350
5 37.572854
6 36.517976

Merge to get range of predictions

In [93]:
new_dnd_df_with_prediction_best_tree_leaf_number_range = (
    new_dnd_df_with_prediction_best_tree_leaf_number
    .merge(leaf_max_df, how = "inner", on = "leaf_number")
    .merge(leaf_min_df, how = "inner", on = "leaf_number")
    .merge(leaf_std_df, how = "inner", on = "leaf_number")
)
new_dnd_df_with_prediction_best_tree_leaf_number_range
Out[93]:
STR DEX CON INT WIS CHA Prediction leaf_number Max_HP Min_HP std_HP
0 9 17 8 13 16 15 68.111421 5 150 7 37.572854
1 17 9 17 18 11 7 40.090909 3 112 6 29.395350
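
The four groupby cells and the chained merges above can also be collapsed into a single aggregation plus one merge. A sketch using pandas named aggregation (pandas >= 0.25):

# One pass over leaf_df: max, min, and std of HP per leaf
leaf_stats_df = (leaf_df.groupby("leaf_number")["HP"]
                        .agg(Max_HP = "max", Min_HP = "min", std_HP = "std")
                        .reset_index())

new_dnd_df_with_prediction_best_tree_leaf_number.merge(leaf_stats_df, how = "inner", on = "leaf_number")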

d20.jpeg

6. Random Forest¶

6.1 Fit the random forest¶

In [94]:
from sklearn.ensemble import RandomForestRegressor

rf = RandomForestRegressor(max_depth = 10, random_state = 666)
rf.fit(train_X, train_y)
Out[94]:
RandomForestRegressor(max_depth=10, random_state=666)
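
By default RandomForestRegressor grows 100 trees (n_estimators = 100); max_depth = 10 caps the depth of each one. Because every tree is trained on a bootstrap sample, the forest can also score itself on the records each tree did not see. An optional sketch using out-of-bag scoring:

rf_oob = RandomForestRegressor(max_depth = 10, random_state = 666, oob_score = True)
rf_oob.fit(train_X, train_y)
rf_oob.oob_score_   # out-of-bag R^2: a rough validation estimate without a holdout set
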
In [95]:
train_y_pred_rf = rf.predict(train_X)
train_y_pred_rf
Out[95]:
array([ 88.93415507, 103.4440601 ,  69.47537521, 102.25756247,
        90.6015338 ,  72.34565388,  55.08351732,  61.08039394,
        56.41217774,  55.12831922,  44.96499255,  50.06883454,
        49.50479846,  71.23535818,  62.30711203,  38.00342379,
        96.4850522 ,  88.06      ,  74.44247935,  39.28533981,
        65.46959216,  35.64007071,  45.80873105,  45.26427884,
        49.58661111,  81.29242571,  49.98355509,  85.12770696,
        56.05009987,  79.14899567,  82.05170197,  48.303423  ,
        43.96298786,  63.78496467,  82.41620807,  70.45749862,
        47.52008668, 115.36753846,  47.96596143,  32.5633898 ,
        75.86841179,  59.54143651,  77.72521177,  51.08989052,
        21.19936315,  77.51343015,  62.72026195,  78.9389704 ,
        72.71332005,  77.33110815,  56.42808742,  89.85577218,
        46.64963591,  85.13207896,  54.96862511,  56.87054614,
        43.91997655,  81.48412106,  72.74496673,  86.26800778,
        41.57996127,  99.21362527, 108.10668571,  57.29779173,
       111.28059477,  53.60430871,  82.44322777,  82.40620381,
        77.08339146,  91.56549427,  63.02749817,  54.35300364,
        82.2662034 ,  97.93201389,  73.0482379 ,  75.82030556,
        77.8802475 ,  72.55873069,  37.73159091,  71.44797238,
        54.50111558,  28.51136153,  81.70698526,  41.82350859,
        54.03108768,  45.3973026 ,  45.78339763,  61.36807   ,
        30.11053968,  82.13027778,  61.15893505, 101.57288501,
        67.48741823,  26.93542019,  88.02408402,  37.45897826,
        85.53040787,  84.95445402,  47.64463051,  67.75539365,
        37.25539184,  91.29050568,  48.96801172,  78.61864337,
        33.39780771,  77.58149604, 118.14853662,  55.76287724,
        68.37750732,  69.79108469,  63.83683632,  66.15682478,
        87.99372067,  54.63183916,  60.71442133,  33.09455779,
        30.96605128,  60.4653631 ,  65.59826143,  81.64426887,
        60.19963427,  45.11525176,  49.20249957,  74.16530098,
        73.19454255,  39.59228974,  54.49690542,  95.66990974,
        35.73416306,  39.08250008,  79.12853771, 112.28103497,
        68.98634541,  80.03228299,  67.99279785,  82.48577546,
        85.0130968 ,  70.083967  ,  74.94647463,  72.72633574,
        43.42748338,  82.81276538,  45.62634505, 106.61943651,
       102.58836559,  80.38613257,  57.37117338,  33.88748496,
       108.94866082,  53.35486688,  56.18110066,  34.60665282,
        78.43217183,  64.39763654,  45.69437607,  68.45122386,
        52.6406046 ,  32.45174725,  48.9233243 ,  29.83394958,
        35.42764936, 109.65492857,  51.11937884,  69.64900305,
        89.00004645,  58.59010944, 107.6301434 ,  63.54540565,
        52.22194841,  71.08996857,  67.97638999,  27.7934586 ,
        80.31229295,  71.33814287,  77.0330639 ,  69.61092666,
        48.67007949,  85.91442843,  41.74883263,  75.18138508,
        77.07730257,  53.18972691,  66.60794683,  65.06615015,
        62.96954681,  32.4638915 ,  25.73599206,  69.36473102,
        80.72062339,  69.3157973 ,  79.51272729,  50.00133222,
        53.77316321,  52.01618708,  51.0263894 ,  49.63979731,
        79.75040143,  42.25055932,  87.69675137,  64.09286381,
        72.61794547,  92.84779201,  98.12629072,  86.48315375,
        90.06800595,  57.20494978,  81.81610457,  51.62287151,
        50.15992208,  32.62498061,  92.94263059,  68.53155923,
        66.82723705,  70.01569548,  51.41607791,  62.33188215,
        71.4641899 ,  43.85667954, 105.44633669,  57.67718836,
        77.16578731,  74.30349661,  53.70944274,  71.63464022,
        75.0198326 ,  69.05771916,  88.68029785,  44.97268484,
       107.37958222,  23.36269481,  49.00191044,  58.6598738 ,
        78.08707725,  35.00298126,  84.92576119,  76.44292787,
        94.72199013,  94.11666398,  59.28519292,  60.45311085,
        65.7338112 ,  95.74016892,  58.27571061,  77.21380488,
        67.20433333,  49.15215437,  54.62181655,  55.11792308,
        34.05189765,  64.52346018,  56.13229827,  81.13058491,
       113.66492836,  76.42701323,  88.65096755,  57.28039899,
        38.63262642,  48.68805405,  59.04782569, 102.37085767,
        31.38396037,  52.23307442,  53.99473035,  80.48299547,
        52.97368403,  68.96615525,  45.42954444,  61.45521545,
        46.90742883,  53.06130937,  47.32722966,  76.63714672,
        53.59161405,  63.17177048,  87.5244256 ,  71.14728393,
        89.9188027 ,  68.04064219,  81.02650193, 108.56730092,
        95.20197455,  40.77743978,  60.0310126 ,  33.02766208,
        75.63861591,  49.40175669,  33.68347467,  67.28065404,
        52.59024541,  90.46681485,  96.96412321,  55.5304823 ,
        67.5809228 , 101.79697451,  43.97335836,  99.45552434,
        69.62235006,  90.65557188, 109.08199052,  63.41604082,
        38.49531318,  72.52384339,  90.26417945,  42.18824789,
        48.05441438,  47.9575858 ,  57.17206151,  30.2311224 ,
        71.52939374,  62.88909123,  68.43390846,  43.77445167,
        95.76515043,  74.0386496 ,  70.54859477,  55.06728316,
        59.3325523 ,  97.05169676,  80.23437332,  99.06046429,
        66.32816194,  89.5258736 ,  61.41168678,  98.17453733,
        76.69305752,  60.18673817,  54.82025641,  57.32873945,
        76.85930604,  45.96900252,  58.76426003,  55.51669719,
        48.82278961,  43.78012896,  58.54019432,  72.18893748,
        32.63400242,  61.5212612 ,  72.75890845,  57.61661024,
        65.39866234,  47.41540043,  77.30090306,  50.89748999,
        61.50746927,  79.49715321,  41.20356734,  56.03800925,
        94.15462125,  63.20338365,  95.28639722,  75.81232017,
        32.09039495,  53.32924083,  59.63109694,  87.66349323,
        71.51329672,  42.87189277,  48.26400166,  60.3417937 ,
        42.20664573,  43.00066797,  55.18067154,  76.27390707,
        78.82137727,  85.451     ,  69.72150089,  55.34246911,
        80.58014985,  86.81124055,  83.98661409,  74.85431944,
        42.61049311,  38.15734865,  85.9956654 ,  43.26669312,
        78.97732552,  36.26190476,  41.72199711,  53.07704693,
        34.18176444,  53.47318892,  48.90472817,  58.57496796,
        92.79668924,  78.4475137 ,  53.60581926,  32.80670145,
        49.42504221,  57.3471786 ,  75.856811  ,  35.04339654,
        52.5227308 ,  34.48859646,  59.79505366,  64.33654553,
        54.4064929 ,  24.49359101,  57.92275495,  57.89097697,
       100.67209492,  88.89994674,  39.90751776,  58.73859727,
        76.14315942,  62.97781913,  71.07756578,  97.56604108,
        99.52729537,  44.67429365,  88.03345238,  68.59522164,
        98.84517573,  60.62938004,  87.14872698,  47.58679546,
        72.82187091,  65.94254085,  89.09710914,  61.96470938,
        69.36748378,  60.7044839 ,  81.60740097,  71.04032612,
        54.72223377,  32.06326537,  67.85120587,  68.54966934,
        90.59010516,  53.08508131, 114.62275726,  81.56617735,
       116.21760762, 100.6281484 , 107.36454736,  54.93407498,
        61.55746122,  74.83490502,  67.25198763,  75.9885209 ])
In [96]:
mse_rf_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_rf)
mse_rf_train
Out[96]:
356.848877962145
In [97]:
# import math

rmse_rf_train = math.sqrt(mse_rf_train)
rmse_rf_train
Out[97]:
18.89044409118391
In [98]:
# regressionSummary was imported from dmba above (In [73])

regressionSummary(train_y, train_y_pred_rf)
Regression statistics

                      Mean Error (ME) : -0.3851
       Root Mean Squared Error (RMSE) : 18.8904
            Mean Absolute Error (MAE) : 15.6621
          Mean Percentage Error (MPE) : -38.5278
Mean Absolute Percentage Error (MAPE) : 52.8393
In [99]:
valid_y_pred_rf = rf.predict(valid_X)
valid_y_pred_rf
Out[99]:
array([85.99613623, 72.81197421, 57.14756602, 65.48415094, 64.23185644,
       73.88297094, 75.63706658, 65.88536111, 77.38755932, 53.3611746 ,
       84.30540558, 81.38338597, 69.54557467, 76.4453801 , 77.30100974,
       60.7457147 , 60.28698095, 54.54960764, 58.39220608, 79.51837662,
       57.45235255, 44.49966493, 80.43081232, 80.9672802 , 79.12526316,
       61.88250468, 68.48486885, 55.05741665, 54.63041239, 66.02297654,
       74.05684091, 61.26967588, 63.94936452, 68.44916792, 79.89559722,
       47.99857592, 77.75612193, 51.24598032, 91.96242819, 95.47827337,
       72.99748216, 61.48166667, 74.24409524, 71.74566751, 65.74643356,
       72.12616881, 59.34396605, 55.31680735, 75.01714982, 67.35775287,
       62.03485455, 69.73391667, 60.42239718, 49.83764828, 78.65800622,
       72.81000229, 84.1797256 , 73.57869858, 69.42034164, 54.17249117,
       80.91616771, 31.0329006 , 58.56498167, 86.61381342, 53.14061718,
       46.71116392, 51.25840051, 57.74422334, 70.76289968, 68.26478858,
       53.12914791, 67.50639943, 58.15775315, 86.5268843 , 54.6770392 ,
       67.69249941, 59.79266884, 74.1011205 , 80.20981113, 47.86735296,
       57.1792636 , 65.4681474 , 71.61567756, 64.21365135, 52.96507108,
       49.00585282, 60.93468812, 66.42856805, 72.62375   , 61.39575291,
       52.3028228 , 70.51631313, 61.68713413, 82.59035772, 62.97275458,
       65.08313092, 62.98707668, 54.98747072, 74.79348167, 77.76127609,
       53.28477675, 65.1869692 , 68.7451453 , 70.07741324, 62.38078571,
       63.3365805 , 73.22574008, 64.41808607, 59.03357159, 56.65995647,
       48.07595775, 66.59146202, 66.1300717 , 64.77757169, 58.36584383,
       63.15126245, 53.69026718, 61.60511499, 68.70015011, 68.68720783,
       61.64060036, 70.09213449, 73.00846138, 63.50531364, 54.47829147,
       51.04458733, 75.76759166, 50.68129187, 85.91485798, 45.009602  ,
       74.03758542, 60.60588345, 66.72822863, 50.40984565, 50.91525236,
       79.34774785, 74.27093714, 82.57503454, 84.23115898, 58.34441943,
       73.50657518, 76.91086606, 94.43031653, 51.995     , 66.51659984,
       59.95371429, 48.8428732 , 74.67547684, 54.12539596, 62.59740704,
       38.82594481, 66.68127848, 59.69954936, 62.24238866, 51.24587194,
       67.46654537, 50.08880019, 78.89875988, 70.45242092, 76.13153563,
       47.54842012, 71.55815842, 77.32311666, 51.02301605, 45.95923124,
       54.17618857, 72.52685606, 59.74554544, 61.51544039, 61.22026958,
       74.06011706, 85.25158929, 66.50520371, 71.80522197, 46.78741758,
       75.15755218, 64.5374446 , 56.70808703, 83.22767766, 59.33948239,
       48.14300962, 80.07204362, 65.39221338, 61.34071397, 50.61837513,
       76.18111707, 48.45899522, 82.61167027, 68.64237587, 69.85818255,
       66.21509905, 82.97757268, 45.1973547 , 61.34981061, 69.38518179,
       91.70802321, 70.23158458, 27.83719221, 59.99675521, 69.57862161,
       70.99210766, 60.44286601, 65.59296299, 70.98130797, 65.47639993,
       68.08674206, 68.73752076, 67.61220989, 74.31800031, 64.41745114,
       67.19080289, 78.22315147, 76.38173478, 56.18366362, 52.71594479,
       56.61547222, 73.43090887, 91.90809606, 59.87043638, 57.65063381,
       51.46242433, 78.90791805, 60.54316234, 51.72708547, 75.8761541 ,
       71.22114286, 43.23641919, 65.66313809, 94.86817947, 67.21091462,
       52.21235575, 72.71997829, 67.16296753, 60.14175456, 68.69688344,
       49.47745978, 61.98980765, 61.37528073, 55.09736371, 66.38018356,
       56.80601079, 64.55983748, 79.77329875, 49.94612653, 56.46429733,
       70.81355007, 73.30237874, 69.29287879, 51.04525876, 57.27526319,
       55.77392348, 70.01296176, 74.8949364 , 83.63937886, 62.96207777,
       53.7513566 , 62.29764406, 69.50477559, 43.34733164, 80.23455717,
       61.96857608, 69.37674323, 37.29      , 65.15223851, 72.95181574,
       50.16673751, 75.97699246, 70.30796224, 64.95026118, 68.25145774,
       61.41429168, 65.10816719, 64.88187943, 68.41595238, 57.63969008,
       65.51860532, 88.77396387, 70.2144709 , 60.30475757, 55.3897939 ,
       61.63742355, 94.91610892, 72.07392157, 60.4653631 , 55.745187  ,
       73.56503645, 42.8781292 , 69.60228571, 84.44706084, 48.06643892,
       71.47878175, 57.75287869, 66.18710983, 83.19506918])
In [100]:
mse_rf_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_rf)
mse_rf_valid
Out[100]:
1426.6006551155326
In [101]:
# import math

rmse_rf_valid = math.sqrt(mse_rf_valid)
rmse_rf_valid
Out[101]:
37.7703674209761
In [102]:
# regressionSummary was imported from dmba above (In [73])

regressionSummary(valid_y, valid_y_pred_rf)
Regression statistics

                      Mean Error (ME) : 3.4807
       Root Mean Squared Error (RMSE) : 37.7704
            Mean Absolute Error (MAE) : 31.1829
          Mean Percentage Error (MPE) : -48.5668
Mean Absolute Percentage Error (MAPE) : 81.0687

The forest's training RMSE (18.89) is far below its validation RMSE (37.77), and that validation RMSE is actually slightly worse than the best tree's (36.34), so the extra model complexity is not buying better generalization on this data.

Variable importance

In [103]:
var_importance = rf.feature_importances_
var_importance
Out[103]:
array([0.17943126, 0.1634644 , 0.16296725, 0.16134494, 0.16488891,
       0.16790324])
In [104]:
# Spread of each variable's importance across the forest's individual trees,
# used below as error bars on the importance plot
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis = 0)
std
Out[104]:
array([0.04659177, 0.03802737, 0.0363087 , 0.03966035, 0.03344671,
       0.03432209])
In [105]:
var_importance_df = pd.DataFrame({"variable": train_X.columns, "importance": var_importance, "std": std})
var_importance_df
Out[105]:
variable importance std
0 STR 0.179431 0.046592
1 DEX 0.163464 0.038027
2 CON 0.162967 0.036309
3 INT 0.161345 0.039660
4 WIS 0.164889 0.033447
5 CHA 0.167903 0.034322
In [106]:
var_importance_df_sorted = var_importance_df.sort_values("importance")
var_importance_df_sorted
Out[106]:
variable importance std
3 INT 0.161345 0.039660
2 CON 0.162967 0.036309
1 DEX 0.163464 0.038027
4 WIS 0.164889 0.033447
5 CHA 0.167903 0.034322
0 STR 0.179431 0.046592
In [107]:
import matplotlib.pyplot as plt

var_importance_plot = var_importance_df_sorted.plot(kind = "barh", xerr = "std", x = "variable", legend = False)
var_importance_plot.set_ylabel("")
var_importance_plot.set_xlabel("Importance")
plt.show()
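
The impurity-based importances above are computed on the training data and come out nearly uniform (all close to 1/6). Permutation importance on the validation set is a common cross-check; a sketch, assuming scikit-learn >= 0.22 for sklearn.inspection.permutation_importance:

from sklearn.inspection import permutation_importance

# Shuffle each predictor in valid_X and measure the resulting drop in score
perm = permutation_importance(rf, valid_X, valid_y, n_repeats = 10, random_state = 666)
pd.DataFrame({"variable": valid_X.columns,
              "perm_importance": perm.importances_mean}).sort_values("perm_importance")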

6.2 Predict using the random forest¶

In [108]:
new_records_dnd_rf_pred = rf.predict(new_dnd_df)
new_records_dnd_rf_pred
Out[108]:
array([64.32568489, 70.02311429])
In [109]:
dnd_rf_prediction_df = pd.DataFrame(new_records_dnd_rf_pred,
                                    columns = ["Prediction"])
dnd_rf_prediction_df
Out[109]:
Prediction
0 64.325685
1 70.023114

Merge with new data

In [110]:
new_dnd_df_with_prediction_rf = pd.concat((new_dnd_df, dnd_rf_prediction_df), axis = 1)
new_dnd_df_with_prediction_rf

# to export
# new_dnd_df_with_prediction_rf.to_csv("whatever_name.csv")
Out[110]:
STR DEX CON INT WIS CHA Prediction
0 9 17 8 13 16 15 64.325685
1 17 9 17 18 11 7 70.023114
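
The single forest prediction hides the disagreement among its trees. The per-tree predictions give a rough analogue of the leaf ranges computed for the best tree; a sketch (.values avoids feature-name warnings, since the individual estimators were fitted on arrays):

# Predictions of all 100 trees for the two new records; shape (100, 2)
per_tree_pred = np.stack([t.predict(new_dnd_df.values) for t in rf.estimators_])
pd.DataFrame({"Prediction": per_tree_pred.mean(axis = 0),   # equals rf.predict(new_dnd_df)
              "std": per_tree_pred.std(axis = 0)})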

d20.jpeg