Nouman Rahman
ProgrammingFire


What are Decision Trees in ML: with Scikit-Learn


Sep 25, 2022




Table of contents

  • Introduction
  • Decision Trees using Scikit-Learn
  • Feature Importance
  • Conclusion

Introduction

A decision tree in general parlance represents a hierarchical series of binary decisions. A decision tree in machine learning works in exactly the same way, except that we let the computer figure out the optimal structure and hierarchy of decisions instead of coming up with the criteria manually.

What a Decision Tree Looks Like

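[Figure: a decision tree, truncated]

Can you see how the model classifies a given input as a series of decisions? The tree is truncated here, but following any path from the root node down to a leaf will result in "Yes" or "No". Do you see how a decision tree differs from a logistic regression model? A logistic regression combines all the features into a single weighted sum, whereas a decision tree applies a sequence of threshold tests, one feature at a time.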

How a Decision Tree is Created

Note the gini value in each box. This is the Gini impurity (also called the Gini index), the criterion the decision tree uses to decide which column to split the data on and at what value to split it. A lower Gini impurity indicates a better split, and a perfect split (only one class on each side) has a Gini impurity of 0.
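For reference, here is the standard definition with a small worked example. If p_k is the fraction of samples of class k in a node S, the node's Gini impurity is G(S), and a candidate split of S into S_L and S_R is scored by the size-weighted impurity of the two sides (lower is better):

```latex
\begin{aligned}
G(S) &= 1 - \sum_{k} p_k^2 \\
G_{\text{split}} &= \frac{|S_L|}{|S|}\,G(S_L) + \frac{|S_R|}{|S|}\,G(S_R) \\[4pt]
\text{node with 4 Yes, 4 No:}\quad G(S) &= 1 - (0.5^2 + 0.5^2) = 0.5 \\
\text{split } \{4\,\text{Yes}\} \mid \{4\,\text{No}\}:\quad G_{\text{split}} &= \tfrac{4}{8}\cdot 0 + \tfrac{4}{8}\cdot 0 = 0 \\
\text{split } \{3\,\text{Yes}, 1\,\text{No}\} \mid \{1\,\text{Yes}, 3\,\text{No}\}:\quad G_{\text{split}} &= \tfrac{4}{8}\,(1 - 0.75^2 - 0.25^2) + \tfrac{4}{8}\,(1 - 0.25^2 - 0.75^2) = 0.375
\end{aligned}
```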

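Conceptually speaking, during training the model evaluates all possible splits across all columns and picks the best one (the one with the lowest weighted Gini impurity). Then it recursively performs an optimal split on each of the two resulting portions. In practice, however, searching for the best possible tree as a whole is far too expensive, so the model uses a greedy heuristic (it picks the best split at each node without looking ahead), combined with some randomization: Scikit-Learn randomly permutes the features it examines at each split, which is why we pass a random_state below.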
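The greedy, recursive procedure can be sketched in plain Python. This is a toy illustration under our own assumptions (numeric features only, thresholds at midpoints between sorted distinct values, a nested-dict tree, and our own function names gini, best_split and build_tree); it is not how Scikit-Learn's optimized implementation is written.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows, labels):
    """Try every column and every midpoint threshold; return (column, threshold, weighted Gini).
    column is None when no split lowers the impurity of the current node."""
    n = len(rows)
    best = (None, None, gini(labels))
    for col in range(len(rows[0])):
        values = sorted(set(row[col] for row in rows))
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left = [y for row, y in zip(rows, labels) if row[col] <= threshold]
            right = [y for row, y in zip(rows, labels) if row[col] > threshold]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if score < best[2]:
                best = (col, threshold, score)
    return best

def build_tree(rows, labels, depth=0, max_depth=3):
    """Greedily pick the best split, then recurse on each side; leaves hold the majority class."""
    col, threshold, _ = best_split(rows, labels)
    if col is None or depth == max_depth:
        return Counter(labels).most_common(1)[0][0]
    sides = {"left": ([], []), "right": ([], [])}
    for row, y in zip(rows, labels):
        side = "left" if row[col] <= threshold else "right"
        sides[side][0].append(row)
        sides[side][1].append(y)
    return {
        "column": col,
        "threshold": threshold,
        "left": build_tree(*sides["left"], depth + 1, max_depth),
        "right": build_tree(*sides["right"], depth + 1, max_depth),
    }

rows = [[1.0, 10.0], [2.0, 20.0], [3.0, 10.0], [4.0, 20.0]]
labels = ["No", "No", "Yes", "Yes"]
print(build_tree(rows, labels))
# {'column': 0, 'threshold': 2.5, 'left': 'No', 'right': 'Yes'}
```

Column 0 at 2.5 separates the classes perfectly (weighted Gini 0), while column 1 leaves both sides at 0.5, so the greedy search picks column 0 and both children become pure leaves.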

Decision Trees using Scikit-Learn

We can use DecisionTreeClassifier from sklearn.tree to train a decision tree.

Training

from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier(random_state=50)
model.fit(X_train, train_targets)
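X_train and train_targets come from preprocessing steps that are not shown in this section. To try the same calls end to end, here is a minimal self-contained stand-in that swaps in a synthetic dataset from Scikit-Learn's make_classification; the dataset, the "Yes"/"No" labels, the column names and the split sizes are our own assumptions, not the article's data.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the article's data (assumption: 6 numeric features, roughly 80% "No").
X, y = make_classification(n_samples=2000, n_features=6, weights=[0.8], random_state=0)
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
y = np.where(y == 1, "Yes", "No")

# Hold out 25% of the rows as a validation set.
X_train, X_val, train_targets, val_targets = train_test_split(
    X, y, test_size=0.25, random_state=0
)

model = DecisionTreeClassifier(random_state=50)
model.fit(X_train, train_targets)
print(model.classes_)  # ['No' 'Yes']
```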

Evaluation

Let's evaluate the decision tree using accuracy_score.

from sklearn.metrics import accuracy_score, confusion_matrix
train_preds = model.predict(X_train)
train_preds

array(['No', 'No', 'No', ..., 'No', 'No', 'No'], dtype=object)

The decision tree also returns probabilities for each prediction.

train_probs = model.predict_proba(X_train)
train_probs

array([[1., 0.],
       [1., 0.],
       [1., 0.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]])

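The decision tree seems quite confident about its predictions. That is expected on the training data: without a depth limit, the tree keeps splitting until nearly every leaf holds a single class, and predict_proba reports the fraction of each class in the leaf a sample falls into.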
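A quick way to see the link between predict_proba and predict: each row of probabilities sums to 1, and predict returns the class with the highest probability. This sketch uses a synthetic dataset (our assumption, not the article's data) and an unrestricted tree, so the training-set probabilities come out as 0 or 1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=6, random_state=0)
model = DecisionTreeClassifier(random_state=50).fit(X, y)

probs = model.predict_proba(X)
preds = model.predict(X)

# Each row sums to 1, and predict() picks the class with the largest probability.
print(np.allclose(probs.sum(axis=1), 1.0))                          # True
print(np.array_equal(preds, model.classes_[probs.argmax(axis=1)]))  # True
# With no depth limit the leaves end up pure, so training probabilities are 0 or 1.
print(np.unique(probs))                                              # [0. 1.]
```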

Let's check the accuracy of its predictions.

accuracy_score(train_targets, train_preds)

0.9999797955307714

The training set accuracy is close to 100%! But we can't rely solely on the training set accuracy; we must evaluate the model on the validation set too.

We can make predictions and compute accuracy in one step using model.score:

model.score(X_val, val_targets)

0.7921188555510418
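For a classifier, model.score returns the mean accuracy on the given data, so it should match calling predict and then accuracy_score yourself. A small check on a synthetic dataset (our assumption, not the article's data):

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=6, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)
model = DecisionTreeClassifier(random_state=50).fit(X_train, y_train)

one_step = model.score(X_val, y_val)                         # predict + accuracy in one call
two_step = accuracy_score(y_val, model.predict(X_val))       # the same thing, done by hand
print(one_step == two_step)  # True
```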

Although the training accuracy is nearly 100%, the accuracy on the validation set is only about 79%, which is barely better than always predicting "No".
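The "always predicting No" baseline can be computed directly with Scikit-Learn's DummyClassifier using strategy="most_frequent". This sketch runs it next to a decision tree on an imbalanced synthetic dataset (about 80% of one class); the data and split are our assumptions, so the printed numbers will not match the article's.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2000, n_features=6, weights=[0.8], random_state=0)
y = np.where(y == 1, "Yes", "No")
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

# Baseline: always predict the most frequent class seen during training.
baseline = DummyClassifier(strategy="most_frequent").fit(X_train, y_train)
tree = DecisionTreeClassifier(random_state=50).fit(X_train, y_train)

print("baseline:", baseline.score(X_val, y_val))
print("tree:    ", tree.score(X_val, y_val))
```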

It appears that the model has learned the training examples perfectly, and doesn't generalize well to previously unseen examples. This phenomenon is called "overfitting", and reducing overfitting is one of the most important parts of any machine learning project.
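One common way to see overfitting is to compare training and validation accuracy while limiting the tree's max_depth. This sketch uses a noisy synthetic dataset and an arbitrary list of depths, both our own assumptions; it illustrates the comparison rather than reporting results for the article's data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# flip_y=0.1 randomly flips about 10% of labels, so a perfect fit means memorizing noise.
X, y = make_classification(n_samples=2000, n_features=6, flip_y=0.1, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

results = {}
for depth in [2, 4, 8, None]:  # None means "keep splitting until the leaves are pure"
    model = DecisionTreeClassifier(max_depth=depth, random_state=50).fit(X_train, y_train)
    results[depth] = (model.score(X_train, y_train), model.score(X_val, y_val))
    print(f"max_depth={depth}: train={results[depth][0]:.3f}, val={results[depth][1]:.3f}")
```

A training accuracy that keeps climbing toward 1.0 while validation accuracy stalls or drops is the typical signature of overfitting.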

Visualization

We can visualize the decision tree learned from the training data.

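from sklearn.tree import plot_tree, export_text
import matplotlib.pyplot as plt
plt.figure(figsize=(80,20))
# max_depth=2 draws only the top two levels of the (much deeper) tree
plot_tree(model, feature_names=list(X_train.columns), max_depth=2, filled=True);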

[Figure: top two levels of the trained decision tree]

Let's check the depth of the tree that was created.

model.tree_.max_depth
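model.tree_.max_depth reads the depth from the low-level Tree object; the estimator also offers the public methods get_depth() and get_n_leaves(). A short sketch on a synthetic dataset (our assumption, not the article's data):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=6, random_state=0)
model = DecisionTreeClassifier(random_state=50).fit(X, y)

print(model.tree_.max_depth == model.get_depth())  # True
print("depth:", model.get_depth(), "leaves:", model.get_n_leaves())
```

Because every internal node has exactly two children, the total node count is always 2 × leaves − 1.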

We can also display the tree as text, which can be easier to follow for deeper trees.

tree_text = export_text(model, max_depth=10, feature_names=list(X_train.columns))
print(tree_text[:5000])
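To see what the text format looks like without the article's data, here is a small sketch that swaps in Scikit-Learn's bundled iris dataset and a depth-2 tree (both our choices, purely for illustration):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Each "|---" line is a threshold test; deeper indentation means a deeper node,
# and "class: ..." lines are the leaves.
tree_text = export_text(model, feature_names=list(iris.feature_names))
print(tree_text)
```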

Feature Importance

Based on the Gini impurity computations, a decision tree assigns an "importance" value to each feature: the more a feature's splits reduce impurity, the higher its importance. These values can help us interpret the results given by a decision tree.

model.feature_importances_

array([3.48942086e-02, 3.23605486e-02, 5.91385668e-02, 2.49619907e-02,
       4.94652143e-02, 5.63334673e-02, 2.80205998e-02, 2.98128801e-02,
       4.02182908e-02, 2.61441297e-01, 3.44145027e-02, 6.20573699e-02,
       1.36406176e-02, 1.69229866e-02, ...])
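Concretely, Scikit-Learn's feature_importances_ is the mean decrease in impurity: for every split it adds the sample-weighted drop in Gini impurity to the feature being split on, then normalizes the totals to sum to 1. The sketch below recomputes this from the fitted tree's low-level tree_ arrays on a synthetic dataset (our assumption, not the article's data) and compares the result with feature_importances_:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=6, random_state=0)
model = DecisionTreeClassifier(random_state=50).fit(X, y)
t = model.tree_

importances = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf nodes have children == -1 and make no split
        continue
    # Impurity decrease of this split, weighted by the number of samples at each node.
    importances[t.feature[node]] += (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right]
    )
importances /= importances.sum()  # normalize so the values sum to 1

print(np.allclose(importances, model.feature_importances_))  # True
```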

Let's turn this into a data frame and visualize the most important features.

import pandas as pd
importance_df = pd.DataFrame({
    'feature': X_train.columns,
    'importance': model.feature_importances_
}).sort_values('importance', ascending=False)
importance_df.head(10)
          feature  importance
9     Humidity3pm    0.261441
11    Pressure3pm    0.062057
2        Rainfall    0.059139
5   WindGustSpeed    0.056333
4        Sunshine    0.049465
8     Humidity9am    0.040218
14        Temp9am    0.035000
0         MinTemp    0.034894
10    Pressure9am    0.034415
1         MaxTemp    0.032361
import matplotlib.pyplot as plt
import seaborn as sns
plt.title('Feature Importance')
sns.barplot(data=importance_df.head(10), x='importance', y='feature');

[Figure: bar chart of the 10 most important features]

Conclusion

In conclusion, decision trees are a powerful machine learning technique for both regression and classification. They are easy to interpret and explain, and they can handle both categorical and numerical data. However, decision trees are prone to overfitting, especially when they are not pruned (for example, by limiting max_depth or by using cost-complexity pruning via ccp_alpha). If you are considering decision trees for your machine learning project, be sure to keep this in mind. If you like this article, please consider sponsoring me. If you have any questions, please ask them in the comments or on Twitter. Thanks for reading!
