What are Decision Trees in ML: with Scikit-Learn

Sep 25, 2022


A decision tree, in general parlance, represents a hierarchical series of binary decisions. A decision tree in machine learning works in exactly the same way, except that we let the computer figure out the optimal structure and hierarchy of decisions instead of coming up with the criteria manually.

What a Decision Tree Looks Like

Group 5.png

Can you see how the model classifies a given input as a series of decisions? The tree is truncated here, but following any path from the root node down to a leaf will result in "Yes" or "No". Do you see how a decision tree differs from a logistic regression model?

How a Decision Tree is Created

Note the gini value in each box of the tree shown earlier. This is the impurity measure, playing the role of a loss function, that the decision tree uses to decide which column should be used for splitting the data, and at what point the column should be split. A lower Gini index indicates a better split. A perfect split (only one class on each side) has a Gini index of 0.
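
To make the Gini index concrete, here is a minimal sketch in plain Python. The helper names `gini` and `split_gini` and the toy labels are made up for illustration; this is not scikit-learn's internal code, only the standard formula (one minus the sum of squared class proportions), with a split scored as the size-weighted average of its two sides.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((count / n) ** 2 for count in counts.values())


def split_gini(left, right):
    """Size-weighted average Gini impurity of the two sides of a split."""
    n = len(left) + len(right)
    return (len(left) / n) * gini(left) + (len(right) / n) * gini(right)


# Toy labels, made up for illustration
print(gini(["Yes", "Yes", "No", "No"]))          # 0.5 (evenly mixed node)
print(split_gini(["No", "No"], ["Yes", "Yes"]))  # 0.0 (perfect split)
print(split_gini(["No", "Yes"], ["Yes", "No"]))  # 0.5 (split that separates nothing)
```

A node containing a single class scores 0, and a two-class node is at its worst (0.5) when the classes are evenly mixed.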

Conceptually speaking, during training the model evaluates all possible splits across all possible columns and picks the best one. Then, it recursively performs an optimal split on each of the two resulting portions. In practice, however, it's very inefficient to check all possible splits, so the model uses a heuristic (a predefined strategy) combined with some randomization.
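
The conceptual procedure can be sketched as a brute-force search. This is an illustration under simplifying assumptions (numeric columns only, candidate thresholds taken at midpoints between sorted distinct values), not how scikit-learn implements it; the function `best_split` and the toy humidity/temperature data are hypothetical.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_split(rows, labels):
    """Try every column and every midpoint threshold; return the lowest weighted Gini."""
    best = None  # (weighted_gini, column_index, threshold)
    n = len(rows)
    for col in range(len(rows[0])):
        values = sorted({row[col] for row in rows})
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left = [y for row, y in zip(rows, labels) if row[col] <= threshold]
            right = [y for row, y in zip(rows, labels) if row[col] > threshold]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[0]:
                best = (score, col, threshold)
    return best


# Toy data, made up for illustration: columns are [humidity, temperature]
rows = [[85, 20], [90, 22], [60, 21], [55, 23], [70, 19], [95, 24]]
labels = ["Yes", "Yes", "No", "No", "No", "Yes"]

score, col, threshold = best_split(rows, labels)
print(score, col, threshold)  # 0.0 0 77.5
```

Here humidity (column 0) separates the classes perfectly at 77.5, so it wins over temperature. A full tree builder would call the same search again on the left and right portions until the nodes are pure or some stopping rule applies.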

Decision Trees using Scikit-Learn

We can use DecisionTreeClassifier from sklearn.tree to train a decision tree.


from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier(random_state=50)
model.fit(X_train, train_targets)


Let's evaluate the decision tree using accuracy_score. First, we generate predictions for the training set.

from sklearn.metrics import accuracy_score, confusion_matrix
train_preds = model.predict(X_train)

array(['No', 'No', 'No', ..., 'No', 'No', 'No'], dtype=object)

The decision tree also returns probabilities for each prediction.

train_probs = model.predict_proba(X_train)

array([[1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.]])

Seems like the decision tree is quite confident about its predictions.

Let's check the accuracy of its predictions.

accuracy_score(train_targets, train_preds)


The training set accuracy is close to 100%! But we can't rely solely on the training set accuracy; we must evaluate the model on the validation set too.

We can make predictions and compute accuracy in one step using model.score:

model.score(X_val, val_targets)


Although the training accuracy is close to 100%, the accuracy on the validation set is just about 79%, which is only marginally better than always predicting "No".
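
To see what "always predicting No" means as a baseline, here is a small sketch in plain Python. The function `majority_baseline_accuracy` and the toy label lists are made up for illustration; on the real data you would pass `train_targets` and `val_targets` instead. scikit-learn also ships `DummyClassifier(strategy="most_frequent")` in `sklearn.dummy` for the same purpose.

```python
from collections import Counter


def majority_baseline_accuracy(train_labels, val_labels):
    """Accuracy obtained by always predicting the most common training label."""
    majority_class, _ = Counter(train_labels).most_common(1)[0]
    correct = sum(1 for y in val_labels if y == majority_class)
    return majority_class, correct / len(val_labels)


# Toy labels, made up for illustration
toy_train = ["No"] * 8 + ["Yes"] * 2
toy_val = ["No", "No", "Yes", "No", "No"]

print(majority_baseline_accuracy(toy_train, toy_val))  # ('No', 0.8)
```

A model is only useful to the extent that its validation accuracy clearly beats this number.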

It appears that the model has learned the training examples perfectly, and doesn't generalize well to previously unseen examples. This phenomenon is called "overfitting", and reducing overfitting is one of the most important parts of any machine learning project.
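
The gap between training and validation accuracy can be reproduced on synthetic data. The sketch below is an illustration only: the dataset comes from `make_classification` rather than from this article's data, and limiting `max_depth` is just one possible way to restrain a tree, chosen here as an assumption for the comparison.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, noisy data (flip_y randomly flips a fraction of the labels)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=42)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.25, random_state=42)

# An unrestricted tree keeps splitting until every training example is classified correctly
full_tree = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)
# A depth-limited tree cannot memorize individual noisy examples
small_tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_tr, y_tr)

for name, tree in [("unrestricted", full_tree), ("max_depth=3", small_tree)]:
    print(name,
          "| depth:", tree.get_depth(),
          "| train accuracy:", tree.score(X_tr, y_tr),
          "| validation accuracy:", tree.score(X_va, y_va))
```

Comparing the two rows shows how much of the unrestricted tree's training accuracy comes from memorizing noise rather than learning patterns that carry over to unseen examples.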


We can visualize the decision tree learned from the training data.

from sklearn.tree import plot_tree, export_text
# Draw only the top levels of the tree; feature names are passed as a plain list
plot_tree(model, feature_names=list(X_train.columns), max_depth=2, filled=True);

Group 5.png

Let's check the depth of the tree that was created.
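
A fitted tree exposes its depth through the `get_depth()` method and the `tree_.max_depth` attribute. The sketch below is self-contained, so it trains a stand-in model (`demo_model`, a name chosen for illustration) on the built-in iris dataset; for the tree trained in this article you would call the same method on `model`.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Stand-in data and model so the snippet runs on its own
X_demo, y_demo = load_iris(return_X_y=True)
demo_model = DecisionTreeClassifier(random_state=50).fit(X_demo, y_demo)

print(demo_model.get_depth())       # depth of the fitted tree
print(demo_model.tree_.max_depth)   # same value, read from the underlying tree object
print(demo_model.get_n_leaves())    # number of leaf nodes
```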


We can also display the tree as text, which can be easier to follow for deeper trees.

tree_text = export_text(model, max_depth=10, feature_names=list(X_train.columns))
print(tree_text)

Feature Importance

Based on the Gini index computations, a decision tree assigns an "importance" value to each feature. These values are available through the feature_importances_ attribute of the trained model and can be used to interpret the results given by a decision tree.


array([3.48942086e-02, 3.23605486e-02, 5.91385668e-02, 2.49619907e-02,
       4.94652143e-02, 5.63334673e-02, 2.80205998e-02, 2.98128801e-02,
       4.02182908e-02, 2.61441297e-01, 3.44145027e-02, 6.20573699e-02,
       1.36406176e-02, 1.69229866e-02, ...])

Let's turn this into a data frame and visualize the most important features.

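import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Pair each feature name with its importance, most important first
importance_df = pd.DataFrame({
    'feature': X_train.columns,
    'importance': model.feature_importances_
}).sort_values('importance', ascending=False)
plt.title('Feature Importance')
sns.barplot(data=importance_df.head(10), x='importance', y='feature');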

Group 6 (1).png


In conclusion, decision trees are a powerful machine learning technique for both regression and classification. They are easy to interpret and explain, and they can handle both categorical and numerical data. However, decision trees can be prone to overfitting, especially when they are not pruned. If you are considering using decision trees for your machine learning project, be sure to keep this in mind.

If you like this article, please consider sponsoring me. If you have any questions, please ask them in the comments or on Twitter. Thanks for reading!

Support Nouman Rahman by becoming a sponsor. Any amount is appreciated!