Dungeons and Dragons¶

dnd-dragons.jpg

Data for demo


1. Load Data¶

1.1 Libraries¶

In [1]:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor

1.2 Data¶

In [2]:
dnd_df = pd.read_csv("super_heroes_dnd_v3a.csv")
dnd_df.head()
Out[2]:
ID Name Gender Race Height Publisher Alignment Weight STR DEX CON INT WIS CHA Level HP
0 A001 A-Bomb Male Human 203.0 Marvel Comics good 441.0 18 11 17 12 13 11 1 7
1 A002 Abe Sapien Male Icthyo Sapien 191.0 Dark Horse Comics good 65.0 16 17 10 13 15 11 8 72
2 A004 Abomination Male Human / Radiation 203.0 Marvel Comics bad 441.0 13 14 13 10 18 15 15 135
3 A009 Agent 13 Female NaN 173.0 Marvel Comics good 61.0 15 18 16 16 17 10 14 140
4 A015 Alex Mercer Male Human NaN Wildstorm bad NaN 14 17 13 12 10 11 9 72
In [3]:
dnd_df.dtypes
Out[3]:
ID            object
Name          object
Gender        object
Race          object
Height       float64
Publisher     object
Alignment     object
Weight       float64
STR            int64
DEX            int64
CON            int64
INT            int64
WIS            int64
CHA            int64
Level          int64
HP             int64
dtype: object
In [4]:
pd.DataFrame(dnd_df.columns.values, columns = ["variables"])
Out[4]:
variables
0 ID
1 Name
2 Gender
3 Race
4 Height
5 Publisher
6 Alignment
7 Weight
8 STR
9 DEX
10 CON
11 INT
12 WIS
13 CHA
14 Level
15 HP

It's a good idea to get a sense of the target variable, HP.

In [5]:
dnd_df["HP"].describe()
Out[5]:
count    734.000000
mean      66.885559
std       36.653877
min        6.000000
25%       36.000000
50%       63.000000
75%       91.000000
max      150.000000
Name: HP, dtype: float64
In [6]:
# Select the six ability scores (columns 8-13) plus HP (column 15), skipping Level
dnd_df_2 = dnd_df.iloc[:, np.r_[8:14, 15]]
dnd_df_2

# Alternatively, use:
# dnd_df.iloc[:, list(range(8, 14)) + [15]]
# Note that the end of the range is exclusive, so 8:14 selects columns 8-13

# Or just list the positions:
# dnd_df.iloc[:, [8, 9, 10, 11, 12, 13, 15]]

# Or use a label range; .loc slices are inclusive on both ends,
# so "STR":"HP" also picks up Level, which must then be dropped:
# dnd_df.loc[:, "STR":"HP"].drop(columns = ["Level"])

# Or specify the variable names:
# dnd_df.loc[:, ["STR", "DEX", "CON", "INT", "WIS", "CHA", "HP"]]
Out[6]:
STR DEX CON INT WIS CHA HP
0 18 11 17 12 13 11 7
1 16 17 10 13 15 11 72
2 13 14 13 10 18 15 135
3 15 18 16 16 17 10 140
4 14 17 13 12 10 11 72
... ... ... ... ... ... ... ...
729 8 14 17 13 14 15 64
730 17 12 11 11 14 10 56
731 18 10 14 17 10 10 49
732 11 11 10 12 15 16 36
733 16 12 18 15 15 16 81

734 rows × 7 columns

2. Training-Validation Split¶

In [7]:
import sklearn
from sklearn.model_selection import train_test_split
In [8]:
predictors = ["STR", "DEX", "CON", "INT", "WIS", "CHA"]
outcome = "HP"
In [9]:
X = dnd_df_2[predictors]
y = dnd_df_2[outcome]
In [10]:
# Hold out 40% of the 734 records for validation (440 train / 294 validation)
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.4, random_state = 666)
In [11]:
train_X.head()
Out[11]:
STR DEX CON INT WIS CHA
650 17 14 16 16 15 17
479 8 18 16 10 14 17
271 9 12 17 10 15 17
647 9 18 16 10 17 13
307 12 16 14 18 15 13
In [12]:
len(train_X)
Out[12]:
440
In [13]:
train_y.head()
Out[13]:
650    117
479    120
271     72
647    117
307    100
Name: HP, dtype: int64
In [14]:
len(train_y)
Out[14]:
440
In [15]:
valid_X.head()
Out[15]:
STR DEX CON INT WIS CHA
389 10 16 15 13 11 10
131 18 10 12 10 16 18
657 10 11 12 11 18 14
421 16 13 11 16 13 11
160 12 16 17 18 11 15
In [16]:
len(valid_X)
Out[16]:
294
In [17]:
valid_y.head()
Out[17]:
389    45
131    42
657    63
421    64
160    54
Name: HP, dtype: int64
In [18]:
len(valid_y)
Out[18]:
294

3. Decision Tree¶

3.1 Large Tree¶

In [19]:
full_tree = DecisionTreeRegressor(random_state = 666)
full_tree
Out[19]:
DecisionTreeRegressor(random_state=666)
In [20]:
full_tree_fit = full_tree.fit(train_X, train_y)

Display the tree

In [21]:
from sklearn import tree

Export the top levels for illustration using max_depth; the whole tree is exported if max_depth is omitted.

In [22]:
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- feature_3 <= 10.50
|   |   |   |--- feature_5 <= 16.00
|   |   |   |   |--- feature_2 <= 13.50
|   |   |   |   |   |--- value: [18.00]
|   |   |   |   |--- feature_2 >  13.50
|   |   |   |   |   |--- feature_0 <= 11.50
|   |   |   |   |   |   |--- value: [48.00]
|   |   |   |   |   |--- feature_0 >  11.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |--- feature_5 >  16.00
|   |   |   |   |--- value: [9.00]
|   |   |--- feature_3 >  10.50
|   |   |   |--- feature_2 <= 13.50
|   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |--- feature_0 <= 12.00
|   |   |   |   |   |   |--- value: [40.00]
|   |   |   |   |   |--- feature_0 >  12.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |--- feature_0 >  17.50
|   |   |   |   |   |--- value: [90.00]
|   |   |   |--- feature_2 >  13.50
|   |   |   |   |--- feature_0 <= 10.50
|   |   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |   |--- value: [56.00]
|   |   |   |   |   |--- feature_3 >  13.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |   |--- feature_0 >  10.50
|   |   |   |   |   |--- feature_2 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 >  17.50
|   |   |   |   |   |   |--- value: [120.00]
|   |--- feature_3 >  14.50
|   |   |--- feature_5 <= 10.50
|   |   |   |--- feature_0 <= 17.50
|   |   |   |   |--- feature_4 <= 16.00
|   |   |   |   |   |--- value: [112.00]
|   |   |   |   |--- feature_4 >  16.00
|   |   |   |   |   |--- value: [84.00]
|   |   |   |--- feature_0 >  17.50
|   |   |   |   |--- value: [49.00]
|   |   |--- feature_5 >  10.50
|   |   |   |--- feature_4 <= 11.50
|   |   |   |   |--- feature_3 <= 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |   |--- feature_4 >  10.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- feature_3 >  16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [54.00]
|   |   |   |   |   |--- feature_4 >  10.50
|   |   |   |   |   |   |--- value: [20.00]
|   |   |   |--- feature_4 >  11.50
|   |   |   |   |--- feature_4 <= 17.50
|   |   |   |   |   |--- feature_5 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_5 >  12.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_4 >  17.50
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 >  17.50
|   |   |   |   |   |   |--- value: [50.00]
|--- feature_1 >  10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- feature_5 <= 10.50
|   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |--- feature_4 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_4 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 >  12.50
|   |   |   |   |   |--- feature_2 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- feature_2 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_5 >  10.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_3 <= 10.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_3 >  10.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_5 >  17.50
|   |   |   |   |   |--- feature_2 <= 15.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 >  15.50
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 >  17.50
|   |   |   |--- feature_1 <= 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_0 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_0 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 >  12.50
|   |   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_0 >  17.50
|   |   |   |   |   |   |--- value: [8.00]
|   |   |   |--- feature_1 >  15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_3 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 >  12.50
|   |   |   |   |   |--- feature_1 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_1 >  16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |--- feature_4 >  17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- feature_3 <= 16.50
|   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_1 >  14.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_3 >  13.50
|   |   |   |   |   |--- feature_1 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_1 >  11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_3 >  16.50
|   |   |   |   |--- feature_2 <= 17.00
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- value: [72.00]
|   |   |   |   |   |--- feature_3 >  17.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_2 >  17.00
|   |   |   |   |   |--- value: [117.00]
|   |   |--- feature_0 >  14.50
|   |   |   |--- feature_0 <= 15.50
|   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |--- feature_5 <= 16.00
|   |   |   |   |   |   |--- value: [9.00]
|   |   |   |   |   |--- feature_5 >  16.00
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |--- feature_1 >  14.50
|   |   |   |   |   |--- value: [28.00]
|   |   |   |--- feature_0 >  15.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 >  12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_5 >  17.50
|   |   |   |   |   |--- value: [112.00]
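
In the printout above, feature_0 through feature_5 refer to column positions in train_X (feature_0 = STR, ..., feature_5 = CHA). export_text also accepts a feature_names argument, so a version with readable labels would be (a sketch, not run here):

# text_representation = tree.export_text(full_tree, feature_names = list(train_X.columns), max_depth = 5)
# print(text_representation)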

Plot the top 5 levels for illustration using max_depth; the whole tree is plotted if max_depth is omitted.

In [23]:
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
Out[23]:
[Text(0.45454545454545453, 0.9285714285714286, 'DEX <= 10.5\nsquared_error = 1382.015\nsamples = 440\nvalue = 65.552'),
 Text(0.1690340909090909, 0.7857142857142857, 'INT <= 14.5\nsquared_error = 933.256\nsamples = 43\nvalue = 52.0'),
 Text(0.07670454545454546, 0.6428571428571429, 'INT <= 10.5\nsquared_error = 742.63\nsamples = 21\nvalue = 64.476'),
 Text(0.03409090909090909, 0.5, 'CHA <= 16.0\nsquared_error = 325.688\nsamples = 4\nvalue = 31.25'),
 Text(0.022727272727272728, 0.35714285714285715, 'CON <= 13.5\nsquared_error = 214.222\nsamples = 3\nvalue = 38.667'),
 Text(0.011363636363636364, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 18.0'),
 Text(0.03409090909090909, 0.21428571428571427, 'STR <= 11.5\nsquared_error = 1.0\nsamples = 2\nvalue = 49.0'),
 Text(0.022727272727272728, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.045454545454545456, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.045454545454545456, 0.35714285714285715, 'squared_error = 0.0\nsamples = 1\nvalue = 9.0'),
 Text(0.11931818181818182, 0.5, 'CON <= 13.5\nsquared_error = 519.855\nsamples = 17\nvalue = 72.294'),
 Text(0.09090909090909091, 0.35714285714285715, 'STR <= 17.5\nsquared_error = 187.484\nsamples = 8\nvalue = 58.375'),
 Text(0.07954545454545454, 0.21428571428571427, 'STR <= 12.0\nsquared_error = 50.98\nsamples = 7\nvalue = 53.857'),
 Text(0.06818181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.09090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.10227272727272728, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 90.0'),
 Text(0.14772727272727273, 0.35714285714285715, 'STR <= 10.5\nsquared_error = 490.0\nsamples = 9\nvalue = 84.667'),
 Text(0.125, 0.21428571428571427, 'INT <= 13.5\nsquared_error = 9.0\nsamples = 2\nvalue = 53.0'),
 Text(0.11363636363636363, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.13636363636363635, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.17045454545454544, 0.21428571428571427, 'CON <= 17.5\nsquared_error = 259.061\nsamples = 7\nvalue = 93.714'),
 Text(0.1590909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.18181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.26136363636363635, 0.6428571428571429, 'CHA <= 10.5\nsquared_error = 824.81\nsamples = 22\nvalue = 40.091'),
 Text(0.2159090909090909, 0.5, 'STR <= 17.5\nsquared_error = 664.222\nsamples = 3\nvalue = 81.667'),
 Text(0.20454545454545456, 0.35714285714285715, 'WIS <= 16.0\nsquared_error = 196.0\nsamples = 2\nvalue = 98.0'),
 Text(0.19318181818181818, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 112.0'),
 Text(0.2159090909090909, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 84.0'),
 Text(0.22727272727272727, 0.35714285714285715, 'squared_error = 0.0\nsamples = 1\nvalue = 49.0'),
 Text(0.3068181818181818, 0.5, 'WIS <= 11.5\nsquared_error = 534.144\nsamples = 19\nvalue = 33.526'),
 Text(0.26136363636363635, 0.35714285714285715, 'INT <= 16.5\nsquared_error = 319.04\nsamples = 5\nvalue = 19.6'),
 Text(0.23863636363636365, 0.21428571428571427, 'WIS <= 10.5\nsquared_error = 2.667\nsamples = 3\nvalue = 8.0'),
 Text(0.22727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.25, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.2840909090909091, 0.21428571428571427, 'WIS <= 10.5\nsquared_error = 289.0\nsamples = 2\nvalue = 37.0'),
 Text(0.2727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.29545454545454547, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3522727272727273, 0.35714285714285715, 'WIS <= 17.5\nsquared_error = 516.964\nsamples = 14\nvalue = 38.5'),
 Text(0.32954545454545453, 0.21428571428571427, 'CHA <= 12.5\nsquared_error = 382.41\nsamples = 10\nvalue = 46.3'),
 Text(0.3181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.3409090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.375, 0.21428571428571427, 'INT <= 17.5\nsquared_error = 321.0\nsamples = 4\nvalue = 19.0'),
 Text(0.36363636363636365, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.38636363636363635, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7400568181818182, 0.7857142857142857, 'WIS <= 17.5\nsquared_error = 1408.574\nsamples = 397\nvalue = 67.02'),
 Text(0.5795454545454546, 0.6428571428571429, 'CON <= 17.5\nsquared_error = 1407.787\nsamples = 359\nvalue = 68.111'),
 Text(0.48863636363636365, 0.5, 'CHA <= 10.5\nsquared_error = 1409.398\nsamples = 323\nvalue = 69.372'),
 Text(0.4431818181818182, 0.35714285714285715, 'CON <= 12.5\nsquared_error = 1278.057\nsamples = 43\nvalue = 78.419'),
 Text(0.42045454545454547, 0.21428571428571427, 'WIS <= 11.5\nsquared_error = 1262.102\nsamples = 14\nvalue = 57.429'),
 Text(0.4090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4318181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4659090909090909, 0.21428571428571427, 'CON <= 16.5\nsquared_error = 970.385\nsamples = 29\nvalue = 88.552'),
 Text(0.45454545454545453, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.4772727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5340909090909091, 0.35714285714285715, 'CHA <= 17.5\nsquared_error = 1415.068\nsamples = 280\nvalue = 67.982'),
 Text(0.5113636363636364, 0.21428571428571427, 'INT <= 10.5\nsquared_error = 1402.723\nsamples = 241\nvalue = 66.593'),
 Text(0.5, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5227272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5568181818181818, 0.21428571428571427, 'CON <= 15.5\nsquared_error = 1405.784\nsamples = 39\nvalue = 76.564'),
 Text(0.5454545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.5681818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6704545454545454, 0.5, 'DEX <= 15.5\nsquared_error = 1251.268\nsamples = 36\nvalue = 56.806'),
 Text(0.625, 0.35714285714285715, 'WIS <= 12.5\nsquared_error = 779.741\nsamples = 21\nvalue = 45.857'),
 Text(0.6022727272727273, 0.21428571428571427, 'STR <= 16.5\nsquared_error = 788.29\nsamples = 10\nvalue = 54.9'),
 Text(0.5909090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6136363636363636, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6477272727272727, 0.21428571428571427, 'STR <= 17.5\nsquared_error = 630.05\nsamples = 11\nvalue = 37.636'),
 Text(0.6363636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.6590909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7159090909090909, 0.35714285714285715, 'WIS <= 12.5\nsquared_error = 1508.649\nsamples = 15\nvalue = 72.133'),
 Text(0.6931818181818182, 0.21428571428571427, 'INT <= 11.5\nsquared_error = 584.889\nsamples = 6\nvalue = 37.333'),
 Text(0.6818181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7045454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7386363636363636, 0.21428571428571427, 'DEX <= 16.5\nsquared_error = 778.889\nsamples = 9\nvalue = 95.333'),
 Text(0.7272727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.75, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9005681818181818, 0.6428571428571429, 'STR <= 14.5\nsquared_error = 1298.469\nsamples = 38\nvalue = 56.711'),
 Text(0.8465909090909091, 0.5, 'INT <= 16.5\nsquared_error = 1403.386\nsamples = 26\nvalue = 61.808'),
 Text(0.8068181818181818, 0.35714285714285715, 'INT <= 13.5\nsquared_error = 1391.959\nsamples = 21\nvalue = 54.429'),
 Text(0.7840909090909091, 0.21428571428571427, 'DEX <= 14.5\nsquared_error = 1660.628\nsamples = 11\nvalue = 67.909'),
 Text(0.7727272727272727, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.7954545454545454, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8295454545454546, 0.21428571428571427, 'DEX <= 11.5\nsquared_error = 676.64\nsamples = 10\nvalue = 39.6'),
 Text(0.8181818181818182, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8409090909090909, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8863636363636364, 0.35714285714285715, 'CON <= 17.0\nsquared_error = 262.16\nsamples = 5\nvalue = 92.8'),
 Text(0.875, 0.21428571428571427, 'INT <= 17.5\nsquared_error = 144.688\nsamples = 4\nvalue = 86.75'),
 Text(0.8636363636363636, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8863636363636364, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.8977272727272727, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 117.0'),
 Text(0.9545454545454546, 0.5, 'STR <= 15.5\nsquared_error = 892.889\nsamples = 12\nvalue = 45.667'),
 Text(0.9318181818181818, 0.35714285714285715, 'DEX <= 14.5\nsquared_error = 76.5\nsamples = 4\nvalue = 13.0'),
 Text(0.9204545454545454, 0.21428571428571427, 'CHA <= 16.0\nsquared_error = 2.0\nsamples = 3\nvalue = 8.0'),
 Text(0.9090909090909091, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9318181818181818, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9431818181818182, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 28.0'),
 Text(0.9772727272727273, 0.35714285714285715, 'CHA <= 17.5\nsquared_error = 500.75\nsamples = 8\nvalue = 62.0'),
 Text(0.9659090909090909, 0.21428571428571427, 'CON <= 12.5\nsquared_error = 164.122\nsamples = 7\nvalue = 54.857'),
 Text(0.9545454545454546, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9772727272727273, 0.07142857142857142, '\n  (...)  \n'),
 Text(0.9886363636363636, 0.21428571428571427, 'squared_error = 0.0\nsamples = 1\nvalue = 112.0')]
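
The long list above is just the return value of plot_tree (one matplotlib Text per node), and the rendered figure is tiny at the default size. A sketch that draws it at a readable size and suppresses the textual dump (matplotlib is assumed, since plot_tree already depends on it):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize = (20, 10))
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5, ax = ax)
plt.show()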

Export the tree and convert it to a picture file.

In [24]:
from sklearn.tree import export_graphviz
In [25]:
# export_graphviz writes full_tree.dot to disk; it returns None when out_file is given
dot_data = export_graphviz(full_tree, out_file='full_tree.dot', feature_names = train_X.columns)
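
One way to convert the .dot file into a picture is the Graphviz dot command-line tool (a sketch; it assumes Graphviz is installed and on the PATH):

import subprocess
subprocess.run(["dot", "-Tpng", "full_tree.dot", "-o", "full_tree.png"], check = True)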

Not very useful: the full tree is far too big to read.

full_tree.png

3.2 Small Tree¶

In [26]:
small_tree = DecisionTreeRegressor(random_state = 666, max_depth = 3, min_samples_split = 25)
small_tree
Out[26]:
DecisionTreeRegressor(max_depth=3, min_samples_split=25, random_state=666)
In [27]:
small_tree_fit = small_tree.fit(train_X, train_y)

Display the tree

In [28]:
# For illustration:
# from sklearn import tree

Export the whole tree; max_depth is omitted here since the tree is already small.

In [29]:
text_representation_2 = tree.export_text(small_tree)
print(text_representation_2)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- value: [64.48]
|   |--- feature_3 >  14.50
|   |   |--- value: [40.09]
|--- feature_1 >  10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- value: [69.37]
|   |   |--- feature_2 >  17.50
|   |   |   |--- value: [56.81]
|   |--- feature_4 >  17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- value: [61.81]
|   |   |--- feature_0 >  14.50
|   |   |   |--- value: [45.67]

Plot the whole tree; max_depth is omitted here since the tree is already small.

In [30]:
tree.plot_tree(small_tree, feature_names = train_X.columns)
Out[30]:
[Text(0.4090909090909091, 0.875, 'DEX <= 10.5\nsquared_error = 1382.015\nsamples = 440\nvalue = 65.552'),
 Text(0.18181818181818182, 0.625, 'INT <= 14.5\nsquared_error = 933.256\nsamples = 43\nvalue = 52.0'),
 Text(0.09090909090909091, 0.375, 'squared_error = 742.63\nsamples = 21\nvalue = 64.476'),
 Text(0.2727272727272727, 0.375, 'squared_error = 824.81\nsamples = 22\nvalue = 40.091'),
 Text(0.6363636363636364, 0.625, 'WIS <= 17.5\nsquared_error = 1408.574\nsamples = 397\nvalue = 67.02'),
 Text(0.45454545454545453, 0.375, 'CON <= 17.5\nsquared_error = 1407.787\nsamples = 359\nvalue = 68.111'),
 Text(0.36363636363636365, 0.125, 'squared_error = 1409.398\nsamples = 323\nvalue = 69.372'),
 Text(0.5454545454545454, 0.125, 'squared_error = 1251.268\nsamples = 36\nvalue = 56.806'),
 Text(0.8181818181818182, 0.375, 'STR <= 14.5\nsquared_error = 1298.469\nsamples = 38\nvalue = 56.711'),
 Text(0.7272727272727273, 0.125, 'squared_error = 1403.386\nsamples = 26\nvalue = 61.808'),
 Text(0.9090909090909091, 0.125, 'squared_error = 892.889\nsamples = 12\nvalue = 45.667')]

Export the tree and convert it to a picture file.

In [31]:
# For illustration
# from sklearn.tree import export_graphviz
In [32]:
dot_data_2 = export_graphviz(small_tree, out_file='small_tree.dot', feature_names = train_X.columns)
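
The same dot conversion as before turns small_tree.dot into a picture (again assuming Graphviz is installed):

import subprocess
subprocess.run(["dot", "-Tpng", "small_tree.dot", "-o", "small_tree.png"], check = True)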

Much better.

small_tree.png

3.3 Model Evaluation¶

3.3.1 Predictions Using the Full Tree¶

On the training set

In [33]:
train_y_pred_full = full_tree.predict(train_X)
train_y_pred_full
Out[33]:
array([117., 120.,  72., 117., 100.,  90.,  49.,  60.,  42.,  50.,  20.,
        32.,  36.,  80.,  36.,  30., 117.,  99.,  90.,  12.,  42.,  28.,
        40.,  24.,  50.,  98.,  32.,  90.,  50.,  88., 130.,  35.,  18.,
        70.,  99.,  66.,  20., 150.,  56.,   8.,  98.,  54.,  81.,  35.,
         6.,  81.,  56.,  90.,  88.,  70.,  60., 104.,  24.,  81.,  54.,
        60.,  30.,  80.,  84.,  98.,  12., 140., 135.,  56., 135.,  30.,
       117.,  99.,  81., 105.,  42.,  48., 100., 110.,  77.,  84.,  84.,
       104.,  16.,  64.,  48.,  16.,  84.,  18.,  48.,  20.,  24.,  54.,
         9.,  99.,  56., 140.,  72.,  20., 112.,   8., 110., 120.,  35.,
        63.,  21.,  99.,  36.,  72.,  16.,  77., 150.,  50.,  90.,  78.,
        60.,  81., 104.,  45.,  56.,   7.,  10.,  60.,  56.,  96.,  72.,
        28.,  40.,  72.,  78.,  18.,  54., 110.,   8.,  16.,  84., 130.,
        88.,  90.,  54., 100., 110.,  72.,  90.,  81.,   8.,  72.,  30.,
       140., 126., 105.,  36.,  18., 140.,  30.,  32.,  18.,  66.,  63.,
        24.,  78.,  21.,  16.,  32.,   9.,  28., 130.,  42.,  70., 105.,
        56., 135.,  63.,  45.,  72.,  72.,   6., 104.,  64.,  96.,  90.,
        20.,  84.,   7.,  90.,  63.,  42.,  60.,  72.,  49.,   9.,   6.,
        90., 130.,  90.,  90.,  42.,  35.,   9.,  45.,  40., 108.,  21.,
       108.,  30.,  84., 112., 135., 112., 112.,  18.,  84.,  50.,  40.,
         9.,  99.,  81.,  72.,  72.,  30.,  60.,  96.,  27., 140.,  60.,
        90.,  72.,  42.,  72.,  81.,  80., 117.,  32., 135.,   8.,  36.,
        63.,  80.,  16., 120.,  72., 100., 110.,  48.,  42.,  64., 130.,
        48.,  90.,  84.,  54.,  48.,  54.,  18.,  80.,  49.,  84., 150.,
        78., 126.,  63.,   9.,  16.,  50., 120.,   8.,  32.,  56., 135.,
        16.,  77.,  24.,  60.,  48.,  18.,   8.,  70.,  63.,  54.,  91.,
        80., 112.,  70., 120., 120., 120.,   8.,  56.,  12.,  88.,  28.,
        18.,  81.,  48.,  91., 117.,  42.,  49., 140.,  28., 120.,  56.,
       110., 130.,  72.,  18.,  77., 126.,  32.,  42.,  36.,  16.,   9.,
        88.,  54.,  72.,  30., 126.,  88.,  84.,  24.,  60., 117., 104.,
       120.,  77., 105.,  42., 110.,  88.,  56.,  35.,  42.,  80.,  30.,
        50.,  48.,  24.,  21.,  56.,  72.,   9.,  63.,  98.,  60.,  48.,
        16., 117.,  30.,  70., 104.,  49.,  21., 130.,  56., 117.,  78.,
         8.,  36.,  48.,  91.,  84.,  24.,  36.,  72.,  10.,  18.,  36.,
        80.,  90., 112.,  63.,  32.,  96.,  72., 108.,  80.,   9.,  18.,
        98.,  18.,  88.,  20.,  18.,  30.,  12.,  54.,  36.,  42., 120.,
        70.,  32.,   7.,  40.,  63.,  77.,  28.,  24.,   7.,  36.,  48.,
        54.,  10.,  56.,  42., 135.,  98.,  10.,  54.,  84.,  54.,  60.,
       117., 135.,  35., 117.,  72., 130.,  63., 110.,  21.,  81.,  48.,
       110.,  54.,  60.,  49.,  91.,  72.,  48.,  10.,  77.,  72., 112.,
        45., 150.,  88., 150., 135., 140.,  32.,  70.,  80.,  72.,  91.])

Get the RMSE for the training set

In [34]:
mse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full)
mse_full_tree_train
Out[34]:
0.0
In [35]:
import math
In [36]:
rmse_full_tree_train = math.sqrt(mse_full_tree_train)
rmse_full_tree_train
Out[36]:
0.0
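
As an aside, recent scikit-learn versions can return the RMSE directly, skipping the math.sqrt step (which spelling works depends on the installed version; squared = False was deprecated in 1.4 in favor of root_mean_squared_error):

# sklearn.metrics.mean_squared_error(train_y, train_y_pred_full, squared = False)
# or, in scikit-learn >= 1.4:
# sklearn.metrics.root_mean_squared_error(train_y, train_y_pred_full)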
In [37]:
# If using the dmba package, install it first:

# pip install dmba

# or

# conda install -c conda-forge dmba


# Then load the library

# import dmba

# from dmba import regressionSummary


import dmba
from dmba import regressionSummary
In [38]:
regressionSummary(train_y, train_y_pred_full)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 0.0000
            Mean Absolute Error (MAE) : 0.0000
          Mean Percentage Error (MPE) : 0.0000
Mean Absolute Percentage Error (MAPE) : 0.0000

The fully grown tree reproduces the training data perfectly (RMSE of 0), a classic sign of overfitting. Now check the validation set.

In [39]:
valid_y_pred_full = full_tree.predict(valid_X)
valid_y_pred_full
Out[39]:
array([110.,   9.,  70., 104.,  84.,  36., 150.,  35.,  90.,  20.,  32.,
       130., 117.,  96., 120.,  21.,  70.,  35.,  18.,  54.,  64.,  27.,
        32.,  72., 126.,  88.,  30.,  20.,  40.,   9., 126., 112.,  21.,
        54., 110.,   9.,  21.,  70., 140., 110.,  72.,  36., 117., 105.,
        72.,  18.,   8.,  54.,  81.,  40.,  36., 135.,  90., 112.,  72.,
        91., 110.,  63., 100.,  27., 110.,   7.,  30.,  90.,  40.,  16.,
        32., 104.,  48.,  90.,  12., 140.,  36., 135.,  16.,  18.,  60.,
        96.,  84.,  54.,  30.,  42.,  80.,  28.,  30.,  16., 130.,   8.,
        90.,  24.,  40., 117.,  49., 140.,  20.,  42.,  72.,  90., 108.,
        42.,  60.,  80., 135.,  48.,  42.,  84.,  91.,  72.,  84.,  36.,
        50.,  50.,  56., 126.,  28.,  42.,  80.,  36.,  96.,  36.,   8.,
        72.,  10.,  72.,  77.,  10., 135.,  54., 140.,  16., 140.,  40.,
         9.,  63.,   8.,  54., 150.,  63.,  56.,  77.,  84., 140., 135.,
         9.,  84.,  36.,   8.,  91.,  16.,  42.,   8., 110.,  56.,  36.,
        20.,  18.,  63., 117.,  81.,  98.,  81.,  84., 120.,  36.,   8.,
        27., 130.,  12.,  63.,   9.,  30.,  84., 130.,  16.,  56.,  90.,
        21.,  63.,  32.,  40.,  40., 140.,  60.,  84.,  18., 100.,  18.,
       126.,  36., 150.,  80.,  77.,  16.,  56.,  88., 110.,  18.,   8.,
        18.,  96.,  42.,  30.,  21., 110.,  54.,  60.,  42., 135., 126.,
       130.,  72., 117.,  72.,  21.,  90.,  70.,  48., 117.,  24.,  60.,
        45.,  91.,   8.,   6., 130.,  81.,  50.,  66.,  96., 130.,  63.,
        63.,   8.,  78.,  70.,  63.,   9.,  72.,  18.,  80.,  54.,  72.,
        56.,  40.,  84.,  60.,  70.,  49.,  64.,  30.,  54.,  28.,  72.,
       104.,  78.,  18.,  56.,  54.,   6.,  77.,  54., 120.,  56.,  54.,
        45.,  56., 126.,  63.,  72.,  80.,  50.,  63., 117.,  63., 112.,
        90., 117.,  36.,  42.,  90., 110., 135.,  56.,  60.,  20.,  56.,
        81.,  84., 150.,   9., 120.,  42.,  72.,  21.])

Get the RMSE for the validation set

In [40]:
mse_full_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_full)
mse_full_tree_valid
Out[40]:
2643.622448979592
In [41]:
import math
In [42]:
rmse_full_tree_valid = math.sqrt(mse_full_tree_valid)
rmse_full_tree_valid
Out[42]:
51.41616913947977
In [43]:
# dmba's regressionSummary was imported above
regressionSummary(valid_y, valid_y_pred_full)
Regression statistics

                      Mean Error (ME) : 3.7041
       Root Mean Squared Error (RMSE) : 51.4162
            Mean Absolute Error (MAE) : 41.2415
          Mean Percentage Error (MPE) : -47.2748
Mean Absolute Percentage Error (MAPE) : 99.1131

3.3.2 Predictions Using the Small Tree¶

On the training set

In [44]:
train_y_pred = small_tree.predict(train_X)
train_y_pred
Out[44]:
array([69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       61.80769231, 40.09090909, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 56.80555556, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556,
       56.80555556, 61.80769231, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 56.80555556, 64.47619048, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703,
       61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 56.80555556,
       56.80555556, 69.37151703, 40.09090909, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       61.80769231, 69.37151703, 56.80555556, 64.47619048, 56.80555556,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 40.09090909, 69.37151703, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 69.37151703, 69.37151703, 61.80769231, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 69.37151703, 69.37151703, 40.09090909,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 69.37151703, 61.80769231, 40.09090909,
       69.37151703, 64.47619048, 64.47619048, 56.80555556, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 56.80555556, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       40.09090909, 64.47619048, 69.37151703, 61.80769231, 61.80769231,
       45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       40.09090909, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703,
       61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 45.66666667, 45.66666667, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703,
       64.47619048, 69.37151703, 40.09090909, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703,
       61.80769231, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 64.47619048, 56.80555556, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703,
       61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 69.37151703, 69.37151703, 56.80555556,
       61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667,
       45.66666667, 69.37151703, 61.80769231, 56.80555556, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703])

Get the RMSE for the training set

In [45]:
mse_small_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred)
mse_small_tree_train
Out[45]:
1320.9654960245605
In [46]:
import math
In [47]:
rmse_small_tree_train = math.sqrt(mse_small_tree_train)
rmse_small_tree_train
Out[47]:
36.34508902210256
In [48]:
# dmba and regressionSummary were imported above
In [49]:
regressionSummary(train_y, train_y_pred)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 36.3451
            Mean Absolute Error (MAE) : 30.4192
          Mean Percentage Error (MPE) : -75.0378
Mean Absolute Percentage Error (MAPE) : 103.4166

On the validation set

In [50]:
valid_y_pred = small_tree.predict(valid_X)
valid_y_pred
Out[50]:
array([69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       61.80769231, 45.66666667, 61.80769231, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 56.80555556, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       64.47619048, 61.80769231, 56.80555556, 69.37151703, 56.80555556,
       64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667,
       69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556,
       61.80769231, 56.80555556, 69.37151703, 40.09090909, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 69.37151703, 61.80769231, 64.47619048, 64.47619048,
       56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703,
       45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       40.09090909, 69.37151703, 61.80769231, 69.37151703, 45.66666667,
       69.37151703, 64.47619048, 69.37151703, 69.37151703, 45.66666667,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048,
       64.47619048, 69.37151703, 45.66666667, 56.80555556, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 61.80769231, 40.09090909, 61.80769231, 69.37151703,
       69.37151703, 61.80769231, 56.80555556, 69.37151703, 64.47619048,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703,
       69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       45.66666667, 40.09090909, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703,
       61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703,
       69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703,
       61.80769231, 45.66666667, 69.37151703, 69.37151703, 56.80555556,
       40.09090909, 64.47619048, 69.37151703, 40.09090909, 69.37151703,
       45.66666667, 69.37151703, 64.47619048, 64.47619048, 61.80769231,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       56.80555556, 69.37151703, 69.37151703, 40.09090909, 40.09090909,
       69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703,
       69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703,
       69.37151703, 45.66666667, 40.09090909, 69.37151703, 45.66666667,
       64.47619048, 69.37151703, 56.80555556, 69.37151703])

Get the RMSE for the validation set

In [51]:
mse_small_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred)
mse_small_tree_valid
Out[51]:
1353.8427521618253
In [52]:
import math
In [53]:
rmse_small_tree_valid = math.sqrt(mse_small_tree_valid)
rmse_small_tree_valid
Out[53]:
36.79460221502368
In [54]:
# dmba's regressionSummary was imported above
regressionSummary(valid_y, valid_y_pred)
Regression statistics

                      Mean Error (ME) : 4.2030
       Root Mean Squared Error (RMSE) : 36.7946
            Mean Absolute Error (MAE) : 30.2375
          Mean Percentage Error (MPE) : -48.2303
Mean Absolute Percentage Error (MAPE) : 80.0992

For context, compare these errors with the spread of HP in the training set.

In [55]:
train_y.describe()
Out[55]:
count    440.000000
mean      65.552273
std       37.217785
min        6.000000
25%       35.000000
50%       63.000000
75%       90.000000
max      150.000000
Name: HP, dtype: float64
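
Collecting the four RMSEs computed above into one table makes the overfitting easy to see: the full tree reproduces the training data perfectly but does far worse than the small tree on the validation data (a sketch using the variables defined earlier):

pd.DataFrame({"train_RMSE": [rmse_full_tree_train, rmse_small_tree_train],
              "valid_RMSE": [rmse_full_tree_valid, rmse_small_tree_valid]},
             index = ["full_tree", "small_tree"])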

4. New Records¶

Load the new records to score.

In [56]:
new_dnd_df = pd.read_csv("new_records_dnd.csv")
new_dnd_df
Out[56]:
STR DEX CON INT WIS CHA
0 9 17 8 13 16 15
1 17 9 17 18 11 7

Predict using the small tree

In [57]:
new_records_dnd_small_pred = small_tree.predict(new_dnd_df)
new_records_dnd_small_pred
Out[57]:
array([69.37151703, 40.09090909])
In [58]:
import pandas as pd
dnd_small_tree_prediction_df = pd.DataFrame(new_records_dnd_small_pred,
                                         columns = ["Prediction"])
dnd_small_tree_prediction_df
Out[58]:
Prediction
0 69.371517
1 40.090909

Combine the predictions with the new data

In [59]:
new_dnd_df_with_prediction = pd.concat((new_dnd_df, dnd_small_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction

# to export
# new_dnd_df_with_prediction.to_csv("whatever_name.csv")
Out[59]:
STR DEX CON INT WIS CHA Prediction
0 9 17 8 13 16 15 69.371517
1 17 9 17 18 11 7 40.090909

4.1 Range of Predictions¶

Get the leaf number for each new record; apply() returns the index of the leaf node each record falls into.

In [60]:
leaf_number_for_new =  small_tree.apply(new_dnd_df)
leaf_number_for_new
Out[60]:
array([6, 3], dtype=int64)
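
The leaf ids are not consecutive because apply() returns indices into the underlying tree structure, where internal nodes and leaves share one numbering. A sketch for listing the leaf ids directly (leaves are the nodes with no children; children_left is -1 for them):

np.where(small_tree.tree_.children_left == -1)[0]
# for this tree: array([ 2,  3,  6,  7,  9, 10])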
In [61]:
leaf_number_for_new_df = pd.DataFrame(leaf_number_for_new, columns = ["leaf_number"])
leaf_number_for_new_df
Out[61]:
leaf_number
0 6
1 3
In [62]:
new_dnd_df_with_prediction_small_tree_leaf_number = pd.concat((new_dnd_df_with_prediction,
                                                               leaf_number_for_new_df), 
                                                              axis = 1)
new_dnd_df_with_prediction_small_tree_leaf_number
Out[62]:
STR DEX CON INT WIS CHA Prediction leaf_number
0 9 17 8 13 16 15 69.371517 6
1 17 9 17 18 11 7 40.090909 3

Get the leaf assignment of each training record

In [63]:
leaf_number = pd.DataFrame(small_tree.apply(train_X), columns = ["leaf_number"], index = train_y.index)
leaf_number
Out[63]:
leaf_number
650 6
479 6
271 6
647 6
307 6
... ...
445 6
414 6
70 6
429 6
236 6

440 rows × 1 columns

Get the HP of each record and corresponding leaf assignment

In [64]:
leaf_df = pd.concat([leaf_number, train_y], axis = 1)
leaf_df
Out[64]:
leaf_number HP
650 6 117
479 6 120
271 6 72
647 6 117
307 6 100
... ... ...
445 6 32
414 6 70
70 6 80
429 6 72
236 6 91

440 rows × 2 columns

Compute descriptive stats of HP within each leaf

In [65]:
leaf_max_df = leaf_df.groupby(by = "leaf_number").max()
leaf_max_df
Out[65]:
HP
leaf_number
2 120
3 112
6 150
7 140
9 140
10 112
In [66]:
leaf_max_df = leaf_max_df.rename(columns = {"HP": "Max_HP"})
leaf_max_df
Out[66]:
Max_HP
leaf_number
2 120
3 112
6 150
7 140
9 140
10 112
In [67]:
leaf_min_df = leaf_df.groupby(by = "leaf_number").min()
leaf_min_df
Out[67]:
HP
leaf_number
2 9
3 6
6 7
7 7
9 6
10 6
In [68]:
leaf_min_df = leaf_min_df.rename(columns = {"HP": "Min_HP"})
leaf_min_df
Out[68]:
Min_HP
leaf_number
2 9
3 6
6 7
7 7
9 6
10 6
In [69]:
leaf_std_df = leaf_df.groupby(by = "leaf_number").std()
leaf_std_df
Out[69]:
HP
leaf_number
2 27.924217
3 29.395350
6 37.600194
7 35.875037
9 38.203685
10 31.209944
In [70]:
leaf_std_df = leaf_std_df.rename(columns = {"HP": "std_HP"})
leaf_std_df
Out[70]:
std_HP
leaf_number
2 27.924217
3 29.395350
6 37.600194
7 35.875037
9 38.203685
10 31.209944
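
As an aside, the three separate groupby calls above can be collapsed into a single .agg() that builds all three columns at once (a sketch; the column names match those used in the merge below):

leaf_df.groupby("leaf_number")["HP"].agg(Max_HP = "max", Min_HP = "min", std_HP = "std")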

Put them all together

In [71]:
new_dnd_df_with_prediction_small_tree_leaf_number_range = pd.merge(
    pd.merge(
        pd.merge(new_dnd_df_with_prediction_small_tree_leaf_number,leaf_max_df, how = "inner", on = "leaf_number"),
        leaf_min_df, how = "inner", on = "leaf_number"),
leaf_std_df, how = "inner", on = "leaf_number")
new_dnd_df_with_prediction_small_tree_leaf_number_range
Out[71]:
STR DEX CON INT WIS CHA Prediction leaf_number Max_HP Min_HP std_HP
0 9 17 8 13 16 15 69.371517 6 150 7 37.600194
1 17 9 17 18 11 7 40.090909 3 112 6 29.395350

d20.jpeg

5. Random Forest¶

5.1 Fit the Random Forest¶

In [72]:
from sklearn.ensemble import RandomForestRegressor

# n_estimators defaults to 100 trees; max_depth limits each tree to 10 levels
rf = RandomForestRegressor(max_depth = 10, random_state = 666)
rf.fit(train_X, train_y)
Out[72]:
RandomForestRegressor(max_depth=10, random_state=666)
In [73]:
train_y_pred_rf = rf.predict(train_X)
train_y_pred_rf
Out[73]:
array([ 88.93415507, 103.4440601 ,  69.47537521, 102.25756247,
        90.6015338 ,  72.34565388,  55.08351732,  61.08039394,
        56.41217774,  55.12831922,  44.96499255,  50.06883454,
        49.50479846,  71.23535818,  62.30711203,  38.00342379,
        96.4850522 ,  88.06      ,  74.44247935,  39.28533981,
        65.46959216,  35.64007071,  45.80873105,  45.26427884,
        49.58661111,  81.29242571,  49.98355509,  85.12770696,
        56.05009987,  79.14899567,  82.05170197,  48.303423  ,
        43.96298786,  63.78496467,  82.41620807,  70.45749862,
        47.52008668, 115.36753846,  47.96596143,  32.5633898 ,
        75.86841179,  59.54143651,  77.72521177,  51.08989052,
        21.19936315,  77.51343015,  62.72026195,  78.9389704 ,
        72.71332005,  77.33110815,  56.42808742,  89.85577218,
        46.64963591,  85.13207896,  54.96862511,  56.87054614,
        43.91997655,  81.48412106,  72.74496673,  86.26800778,
        41.57996127,  99.21362527, 108.10668571,  57.29779173,
       111.28059477,  53.60430871,  82.44322777,  82.40620381,
        77.08339146,  91.56549427,  63.02749817,  54.35300364,
        82.2662034 ,  97.93201389,  73.0482379 ,  75.82030556,
        77.8802475 ,  72.55873069,  37.73159091,  71.44797238,
        54.50111558,  28.51136153,  81.70698526,  41.82350859,
        54.03108768,  45.3973026 ,  45.78339763,  61.36807   ,
        30.11053968,  82.13027778,  61.15893505, 101.57288501,
        67.48741823,  26.93542019,  88.02408402,  37.45897826,
        85.53040787,  84.95445402,  47.64463051,  67.75539365,
        37.25539184,  91.29050568,  48.96801172,  78.61864337,
        33.39780771,  77.58149604, 118.14853662,  55.76287724,
        68.37750732,  69.79108469,  63.83683632,  66.15682478,
        87.99372067,  54.63183916,  60.71442133,  33.09455779,
        30.96605128,  60.4653631 ,  65.59826143,  81.64426887,
        60.19963427,  45.11525176,  49.20249957,  74.16530098,
        73.19454255,  39.59228974,  54.49690542,  95.66990974,
        35.73416306,  39.08250008,  79.12853771, 112.28103497,
        68.98634541,  80.03228299,  67.99279785,  82.48577546,
        85.0130968 ,  70.083967  ,  74.94647463,  72.72633574,
        43.42748338,  82.81276538,  45.62634505, 106.61943651,
       102.58836559,  80.38613257,  57.37117338,  33.88748496,
       108.94866082,  53.35486688,  56.18110066,  34.60665282,
        78.43217183,  64.39763654,  45.69437607,  68.45122386,
        52.6406046 ,  32.45174725,  48.9233243 ,  29.83394958,
        35.42764936, 109.65492857,  51.11937884,  69.64900305,
        89.00004645,  58.59010944, 107.6301434 ,  63.54540565,
        52.22194841,  71.08996857,  67.97638999,  27.7934586 ,
        80.31229295,  71.33814287,  77.0330639 ,  69.61092666,
        48.67007949,  85.91442843,  41.74883263,  75.18138508,
        77.07730257,  53.18972691,  66.60794683,  65.06615015,
        62.96954681,  32.4638915 ,  25.73599206,  69.36473102,
        80.72062339,  69.3157973 ,  79.51272729,  50.00133222,
        53.77316321,  52.01618708,  51.0263894 ,  49.63979731,
        79.75040143,  42.25055932,  87.69675137,  64.09286381,
        72.61794547,  92.84779201,  98.12629072,  86.48315375,
        90.06800595,  57.20494978,  81.81610457,  51.62287151,
        50.15992208,  32.62498061,  92.94263059,  68.53155923,
        66.82723705,  70.01569548,  51.41607791,  62.33188215,
        71.4641899 ,  43.85667954, 105.44633669,  57.67718836,
        77.16578731,  74.30349661,  53.70944274,  71.63464022,
        75.0198326 ,  69.05771916,  88.68029785,  44.97268484,
       107.37958222,  23.36269481,  49.00191044,  58.6598738 ,
        78.08707725,  35.00298126,  84.92576119,  76.44292787,
        94.72199013,  94.11666398,  59.28519292,  60.45311085,
        65.7338112 ,  95.74016892,  58.27571061,  77.21380488,
        67.20433333,  49.15215437,  54.62181655,  55.11792308,
        34.05189765,  64.52346018,  56.13229827,  81.13058491,
       113.66492836,  76.42701323,  88.65096755,  57.28039899,
        38.63262642,  48.68805405,  59.04782569, 102.37085767,
        31.38396037,  52.23307442,  53.99473035,  80.48299547,
        52.97368403,  68.96615525,  45.42954444,  61.45521545,
        46.90742883,  53.06130937,  47.32722966,  76.63714672,
        53.59161405,  63.17177048,  87.5244256 ,  71.14728393,
        89.9188027 ,  68.04064219,  81.02650193, 108.56730092,
        95.20197455,  40.77743978,  60.0310126 ,  33.02766208,
        75.63861591,  49.40175669,  33.68347467,  67.28065404,
        52.59024541,  90.46681485,  96.96412321,  55.5304823 ,
        67.5809228 , 101.79697451,  43.97335836,  99.45552434,
        69.62235006,  90.65557188, 109.08199052,  63.41604082,
        38.49531318,  72.52384339,  90.26417945,  42.18824789,
        48.05441438,  47.9575858 ,  57.17206151,  30.2311224 ,
        71.52939374,  62.88909123,  68.43390846,  43.77445167,
        95.76515043,  74.0386496 ,  70.54859477,  55.06728316,
        59.3325523 ,  97.05169676,  80.23437332,  99.06046429,
        66.32816194,  89.5258736 ,  61.41168678,  98.17453733,
        76.69305752,  60.18673817,  54.82025641,  57.32873945,
        76.85930604,  45.96900252,  58.76426003,  55.51669719,
        48.82278961,  43.78012896,  58.54019432,  72.18893748,
        32.63400242,  61.5212612 ,  72.75890845,  57.61661024,
        65.39866234,  47.41540043,  77.30090306,  50.89748999,
        61.50746927,  79.49715321,  41.20356734,  56.03800925,
        94.15462125,  63.20338365,  95.28639722,  75.81232017,
        32.09039495,  53.32924083,  59.63109694,  87.66349323,
        71.51329672,  42.87189277,  48.26400166,  60.3417937 ,
        42.20664573,  43.00066797,  55.18067154,  76.27390707,
        78.82137727,  85.451     ,  69.72150089,  55.34246911,
        80.58014985,  86.81124055,  83.98661409,  74.85431944,
        42.61049311,  38.15734865,  85.9956654 ,  43.26669312,
        78.97732552,  36.26190476,  41.72199711,  53.07704693,
        34.18176444,  53.47318892,  48.90472817,  58.57496796,
        92.79668924,  78.4475137 ,  53.60581926,  32.80670145,
        49.42504221,  57.3471786 ,  75.856811  ,  35.04339654,
        52.5227308 ,  34.48859646,  59.79505366,  64.33654553,
        54.4064929 ,  24.49359101,  57.92275495,  57.89097697,
       100.67209492,  88.89994674,  39.90751776,  58.73859727,
        76.14315942,  62.97781913,  71.07756578,  97.56604108,
        99.52729537,  44.67429365,  88.03345238,  68.59522164,
        98.84517573,  60.62938004,  87.14872698,  47.58679546,
        72.82187091,  65.94254085,  89.09710914,  61.96470938,
        69.36748378,  60.7044839 ,  81.60740097,  71.04032612,
        54.72223377,  32.06326537,  67.85120587,  68.54966934,
        90.59010516,  53.08508131, 114.62275726,  81.56617735,
       116.21760762, 100.6281484 , 107.36454736,  54.93407498,
        61.55746122,  74.83490502,  67.25198763,  75.9885209 ])
In [74]:
mse_rf_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_rf)
mse_rf_train
Out[74]:
356.848877962145
In [75]:
# import math

rmse_rf_train = math.sqrt(mse_rf_train)
rmse_rf_train
Out[75]:
18.89044409118391
In [76]:
# dmba's regressionSummary was imported above
regressionSummary(train_y, train_y_pred_rf)
Regression statistics

                      Mean Error (ME) : -0.3851
       Root Mean Squared Error (RMSE) : 18.8904
            Mean Absolute Error (MAE) : 15.6621
          Mean Percentage Error (MPE) : -38.5278
Mean Absolute Percentage Error (MAPE) : 52.8393
In [77]:
valid_y_pred_rf = rf.predict(valid_X)
valid_y_pred_rf
Out[77]:
array([85.99613623, 72.81197421, 57.14756602, 65.48415094, 64.23185644,
       73.88297094, 75.63706658, 65.88536111, 77.38755932, 53.3611746 ,
       84.30540558, 81.38338597, 69.54557467, 76.4453801 , 77.30100974,
       60.7457147 , 60.28698095, 54.54960764, 58.39220608, 79.51837662,
       57.45235255, 44.49966493, 80.43081232, 80.9672802 , 79.12526316,
       61.88250468, 68.48486885, 55.05741665, 54.63041239, 66.02297654,
       74.05684091, 61.26967588, 63.94936452, 68.44916792, 79.89559722,
       47.99857592, 77.75612193, 51.24598032, 91.96242819, 95.47827337,
       72.99748216, 61.48166667, 74.24409524, 71.74566751, 65.74643356,
       72.12616881, 59.34396605, 55.31680735, 75.01714982, 67.35775287,
       62.03485455, 69.73391667, 60.42239718, 49.83764828, 78.65800622,
       72.81000229, 84.1797256 , 73.57869858, 69.42034164, 54.17249117,
       80.91616771, 31.0329006 , 58.56498167, 86.61381342, 53.14061718,
       46.71116392, 51.25840051, 57.74422334, 70.76289968, 68.26478858,
       53.12914791, 67.50639943, 58.15775315, 86.5268843 , 54.6770392 ,
       67.69249941, 59.79266884, 74.1011205 , 80.20981113, 47.86735296,
       57.1792636 , 65.4681474 , 71.61567756, 64.21365135, 52.96507108,
       49.00585282, 60.93468812, 66.42856805, 72.62375   , 61.39575291,
       52.3028228 , 70.51631313, 61.68713413, 82.59035772, 62.97275458,
       65.08313092, 62.98707668, 54.98747072, 74.79348167, 77.76127609,
       53.28477675, 65.1869692 , 68.7451453 , 70.07741324, 62.38078571,
       63.3365805 , 73.22574008, 64.41808607, 59.03357159, 56.65995647,
       48.07595775, 66.59146202, 66.1300717 , 64.77757169, 58.36584383,
       63.15126245, 53.69026718, 61.60511499, 68.70015011, 68.68720783,
       61.64060036, 70.09213449, 73.00846138, 63.50531364, 54.47829147,
       51.04458733, 75.76759166, 50.68129187, 85.91485798, 45.009602  ,
       74.03758542, 60.60588345, 66.72822863, 50.40984565, 50.91525236,
       79.34774785, 74.27093714, 82.57503454, 84.23115898, 58.34441943,
       73.50657518, 76.91086606, 94.43031653, 51.995     , 66.51659984,
       59.95371429, 48.8428732 , 74.67547684, 54.12539596, 62.59740704,
       38.82594481, 66.68127848, 59.69954936, 62.24238866, 51.24587194,
       67.46654537, 50.08880019, 78.89875988, 70.45242092, 76.13153563,
       47.54842012, 71.55815842, 77.32311666, 51.02301605, 45.95923124,
       54.17618857, 72.52685606, 59.74554544, 61.51544039, 61.22026958,
       74.06011706, 85.25158929, 66.50520371, 71.80522197, 46.78741758,
       75.15755218, 64.5374446 , 56.70808703, 83.22767766, 59.33948239,
       48.14300962, 80.07204362, 65.39221338, 61.34071397, 50.61837513,
       76.18111707, 48.45899522, 82.61167027, 68.64237587, 69.85818255,
       66.21509905, 82.97757268, 45.1973547 , 61.34981061, 69.38518179,
       91.70802321, 70.23158458, 27.83719221, 59.99675521, 69.57862161,
       70.99210766, 60.44286601, 65.59296299, 70.98130797, 65.47639993,
       68.08674206, 68.73752076, 67.61220989, 74.31800031, 64.41745114,
       67.19080289, 78.22315147, 76.38173478, 56.18366362, 52.71594479,
       56.61547222, 73.43090887, 91.90809606, 59.87043638, 57.65063381,
       51.46242433, 78.90791805, 60.54316234, 51.72708547, 75.8761541 ,
       71.22114286, 43.23641919, 65.66313809, 94.86817947, 67.21091462,
       52.21235575, 72.71997829, 67.16296753, 60.14175456, 68.69688344,
       49.47745978, 61.98980765, 61.37528073, 55.09736371, 66.38018356,
       56.80601079, 64.55983748, 79.77329875, 49.94612653, 56.46429733,
       70.81355007, 73.30237874, 69.29287879, 51.04525876, 57.27526319,
       55.77392348, 70.01296176, 74.8949364 , 83.63937886, 62.96207777,
       53.7513566 , 62.29764406, 69.50477559, 43.34733164, 80.23455717,
       61.96857608, 69.37674323, 37.29      , 65.15223851, 72.95181574,
       50.16673751, 75.97699246, 70.30796224, 64.95026118, 68.25145774,
       61.41429168, 65.10816719, 64.88187943, 68.41595238, 57.63969008,
       65.51860532, 88.77396387, 70.2144709 , 60.30475757, 55.3897939 ,
       61.63742355, 94.91610892, 72.07392157, 60.4653631 , 55.745187  ,
       73.56503645, 42.8781292 , 69.60228571, 84.44706084, 48.06643892,
       71.47878175, 57.75287869, 66.18710983, 83.19506918])
In [78]:
mse_rf_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_rf)
mse_rf_valid
Out[78]:
1426.6006551155326
In [79]:
# import math

rmse_rf_valid = math.sqrt(mse_rf_valid)
rmse_rf_valid
Out[79]:
37.7703674209761
In [80]:
# dmba's regressionSummary was imported above
regressionSummary(valid_y, valid_y_pred_rf)
Regression statistics

                      Mean Error (ME) : 3.4807
       Root Mean Squared Error (RMSE) : 37.7704
            Mean Absolute Error (MAE) : 31.1829
          Mean Percentage Error (MPE) : -48.5668
Mean Absolute Percentage Error (MAPE) : 81.0687
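
Comparing validation RMSEs across the three models (a sketch using the variables computed above): the constrained small tree and the forest land close together, and both beat the fully grown tree.

pd.Series({"full_tree": rmse_full_tree_valid,
           "small_tree": rmse_small_tree_valid,
           "random_forest": rmse_rf_valid},
          name = "valid_RMSE")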

5.2 Predict Using the Random Forest¶

In [81]:
new_records_dnd_rf_pred = rf.predict(new_dnd_df)
new_records_dnd_rf_pred
Out[81]:
array([64.32568489, 70.02311429])
In [82]:
dnd_rf_prediction_df = pd.DataFrame(new_records_dnd_rf_pred,
                                 columns = ["Prediction"])
dnd_rf_prediction_df
Out[82]:
Prediction
0 64.325685
1 70.023114

Combine the predictions with the new data set

In [83]:
new_dnd_df_with_prediction_rf = pd.concat((new_dnd_df, dnd_rf_prediction_df), axis = 1)
new_dnd_df_with_prediction_rf

# to export
# new_dnd_df_with_prediction_rf.to_csv("whatever_name.csv")
Out[83]:
STR DEX CON INT WIS CHA Prediction
0 9 17 8 13 16 15 64.325685
1 17 9 17 18 11 7 70.023114

d20.jpeg