
Decision tree with scikit-learn


In this topic, you will learn how to implement a decision tree classifier with scikit-learn. In essence, a decision tree is a flowchart-like structure made of many conditions: decision rules that are generated during the learning process.

We will learn how to train a decision tree and explore the parameters that affect the quality of the model.

Loading data

We'll use the Mines vs Rocks dataset. It contains 208 sonar signals bounced off two kinds of objects: metal cylinders (mines) and rocks, recorded from a variety of angles. Each of the first 60 columns holds the energy of the returned signal within a particular frequency band, as a number between 0.0 and 1.0. The last column (60) holds the label: 'M' for a mine and 'R' for a rock, which we encode as 0 and 1, respectively.

First, let's load the data and split it into train and test sets:

import pandas as pd
from sklearn.model_selection import train_test_split

link = "http://archive.ics.uci.edu/ml/machine-learning-databases/undocumented/connectionist-bench/sonar/sonar.all-data"
df_sonar = pd.read_csv(link, header=None)
# Encode the labels: 'M' (mine) -> 0, 'R' (rock) -> 1
df_sonar[60] = df_sonar[60].map({'M': 0, 'R': 1})

X = df_sonar.drop(60, axis=1)
y = df_sonar[60]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=34)

df_sonar.head()

Output (columns 4 to 57 are hidden):

        0       1       2       3  ...      58      59  60
0  0.0200  0.0371  0.0428  0.0207  ...  0.0084  0.0032   1
1  0.0453  0.0523  0.0843  0.0689  ...  0.0049  0.0044   1
2  0.0262  0.0582  0.1099  0.1083  ...  0.0164  0.0078   1
3  0.0100  0.0171  0.0623  0.0205  ...  0.0044  0.0117   1

Fitting a model

To train a model, we will use the DecisionTreeClassifier class from the sklearn.tree module. Let's create a decision tree model:

from sklearn.tree import DecisionTreeClassifier

# Creating a DecisionTreeClassifier object
clf = DecisionTreeClassifier(random_state=34)

# Training a model
clf = clf.fit(X_train, y_train)
# >>> DecisionTreeClassifier(random_state=34)

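As you can see, we've set the random_state parameter for our model. What does it do? Even with the default settings, DecisionTreeClassifier randomly permutes the features at each split, so when several splits are equally good, the one that gets chosen can vary from run to run. Fixing random_state makes training, and therefore the results of your model, reproducible. Try it yourself: train decision tree models with different random states and then compare their predictions.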
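To see this in practice, here is a minimal sketch. It assumes a synthetic stand-in for the sonar data generated with make_classification (208 samples, 60 features), so it runs without downloading anything; the fit_predict helper is hypothetical and exists only for this illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in with the sonar data's shape (208 x 60), not the real dataset
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=34)

def fit_predict(seed):
    """Hypothetical helper: fit an unconstrained tree with the given seed, predict the test set."""
    model = DecisionTreeClassifier(random_state=seed)
    model.fit(X_train, y_train)
    return model.predict(X_test)

run_1 = fit_predict(34)
run_2 = fit_predict(34)
other_seed = fit_predict(0)

# Same seed: the same tree is built, so the predictions are identical
print("Identical with the same seed:", np.array_equal(run_1, run_2))
# Different seed: ties between equally good splits may be broken differently
print("Predictions that differ with seed 0:", int((run_1 != other_seed).sum()))
```

The first line always prints True; the number of differing predictions for another seed depends on the data and may well be zero.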

After the model has been fit, we can access the following properties of the model:

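  • classes_ - an array of class labels.

  • n_features_in_ - the number of features seen during fit. Older scikit-learn versions called this attribute n_features_; it has since been removed.

  • feature_importances_ - impurity-based feature importances, computed as the total decrease in node impurity brought by a feature, weighted by the probability of reaching the corresponding nodes, and normalized to sum to 1. The higher the value, the more important the feature is.

You can read about all the other model attributes in the scikit-learn documentation.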
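A short sketch of reading these attributes from a fitted tree follows. It again assumes a synthetic stand-in dataset with 60 features instead of the real sonar data, so the feature indices and importance values it prints are not those of the sonar model.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the sonar features and labels
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)
clf = DecisionTreeClassifier(random_state=34).fit(X, y)

print(clf.classes_)          # class labels seen during fit: [0 1]
print(clf.n_features_in_)    # number of features seen during fit: 60

# Indices of the five most important features, highest importance first
top5 = np.argsort(clf.feature_importances_)[::-1][:5]
for idx in top5:
    print(f"feature {idx}: {clf.feature_importances_[idx]:.3f}")

# Impurity-based importances are normalized to sum to 1
print(round(clf.feature_importances_.sum(), 6))
```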

Making predictions

Now we will make predictions for the training data. Let's pick one record (for example, the one at position 55 of the training set) and compare its predicted class with the target value from the dataset.

pred_train = clf.predict(X_train)
prediction = pred_train[55]
result = y_train.iloc[55]
print(prediction, result)
# >>> 1 1

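See, the prediction and the target value are the same. Not bad! But a single record says little, so let's take a look at the big picture. In this topic, we will use the score() method to evaluate our model: it returns the mean accuracy, that is, the share of correctly classified samples, on the data and labels passed to it. Let's compute it for both the training and the test set.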

train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)
print("Accuracy on train set: {}".format(train_score))
print("Accuracy on test set: {}".format(round(test_score, 3)))

# >>> Accuracy on train set: 1.0
# >>> Accuracy on test set: 0.725

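A test accuracy of 72.5% is a reasonable starting point, but notice the perfect score on the training set: the unconstrained tree has memorized the training data, and the large gap between the two scores is a sign of overfitting. There are other ways to evaluate a decision tree model, but we won't go in-depth into decision tree evaluation metrics in this topic.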
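Because score() for a classifier reports plain accuracy, you can reproduce it by hand. The sketch below, which assumes a synthetic stand-in dataset rather than the sonar data, compares the manual computation with accuracy_score from sklearn.metrics and with score().

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the sonar data
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=34)
clf = DecisionTreeClassifier(random_state=34).fit(X_train, y_train)

pred_test = clf.predict(X_test)
by_hand = (pred_test == y_test).mean()          # share of correct predictions
by_metric = accuracy_score(y_test, pred_test)   # the same quantity via sklearn.metrics
by_score = clf.score(X_test, y_test)            # what score() reports
print(by_hand, by_metric, by_score)             # three equal values
```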

Decision tree parameters

Above, we've created a DecisionTreeClassifier object without passing any parameters other than random_state. However, the DecisionTreeClassifier class offers a lot of possibilities to tune our decision tree model and make it perform better. Let's list the most essential parameters:

  • criterion (default: 'gini'; options: 'gini', 'entropy', 'log_loss')

From the previous topic, you know that there are different criteria for finding the best split of the data in each node. The criterion parameter tells the classifier which one to apply to measure the quality of a split. Supported criteria are 'gini' for the Gini impurity and 'entropy' or 'log_loss' for the Shannon information gain (these two options are equivalent). Since the classifier uses Gini impurity by default, let's try 'entropy' instead.

clf = DecisionTreeClassifier(criterion='entropy', random_state=34)
clf.fit(X_train, y_train)
# >>> DecisionTreeClassifier(criterion='entropy', random_state=34)

print(clf.score(X_test, y_test))
# >>> 0.7681159420289855

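With the entropy criterion, we got a higher accuracy on the test set: 0.725 for 'gini' vs 0.768 for 'entropy'. Keep in mind that this comparison is based on a single train/test split with only 69 test samples, so it doesn't prove that entropy is the better criterion in general.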
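To make the two criteria more concrete, here is a from-scratch sketch of the textbook formulas: Gini impurity is 1 - Σ p_k² and entropy is -Σ p_k·log2(p_k), where p_k is the share of class k in a node. A split is scored by how much it reduces the parent's impurity, weighting each child by its share of the samples. The class counts below are hypothetical, and the helper functions illustrate the math rather than scikit-learn's internal code.

```python
import numpy as np

def gini(counts):
    """Gini impurity of a node: 1 - sum(p_k ** 2)."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    return float(1.0 - np.sum(p ** 2))

def entropy(counts):
    """Shannon entropy of a node in bits: -sum(p_k * log2(p_k)); empty classes contribute 0."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

def impurity_decrease(parent, left, right, measure):
    """Parent impurity minus the size-weighted impurity of the two child nodes."""
    n = np.sum(parent)
    children = (np.sum(left) / n) * measure(left) + (np.sum(right) / n) * measure(right)
    return float(measure(parent) - children)

# Hypothetical node with 30 samples of class 0 and 10 of class 1, split into two children
parent, left, right = [30, 10], [28, 2], [2, 8]
print(round(gini(parent), 3), round(entropy(parent), 3))    # 0.375 0.811
print(round(impurity_decrease(parent, left, right, gini), 3))
print(round(impurity_decrease(parent, left, right, entropy), 3))
```

Both measures are 0 for a pure node and grow as the classes get more mixed; the tree picks the split with the largest decrease under the chosen measure.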

  • max_depth (default: None, int)

The next parameter, max_depth, is the maximum depth of a tree, that is, the length of the longest path from the root to a leaf. It's one of the parameters that control when splitting stops: with the default None, nodes are expanded until all leaves are pure or contain fewer than min_samples_split samples. First, let's find the depth of our current (entropy-based) model with the get_depth() method.

clf.get_depth()
# >>> 5

Let's pass 2 to max_depth and see what happens:

clf = DecisionTreeClassifier(random_state=34, max_depth=2)
clf.fit(X_train, y_train)

# For training data:
print(clf.score(X_train, y_train))
# >>> 0.7697841726618705

# For test set:
print(clf.score(X_test, y_test))
# >>> 0.6231884057971014

The performance of our model on both the training and the test data has worsened because the model is now underfitting. With max_depth=2, the tree can't grow deeper than two levels, so it isn't complex enough to capture the relations between the features and the target variable.

The opposite problem occurs when the tree is too deep: a high max_depth (or no limit at all) may cause overfitting. The model then fits the training data, noise included, so closely that it loses the ability to generalize. As a result, we get an excellent score on the training data and a much less satisfactory score on the test data, just as we saw with the unconstrained tree (1.0 vs 0.725).
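A depth-2 tree is small enough to read in full, which makes its limitations easy to see: it has at most four leaves. The sketch below prints the learned rules with sklearn.tree.export_text; it assumes a synthetic stand-in dataset, so the features and thresholds it shows are not those of the sonar model.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumption: synthetic stand-in for the sonar data
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)
shallow = DecisionTreeClassifier(max_depth=2, random_state=34).fit(X, y)

# Text dump of the learned rules; without feature_names, columns are named feature_0, feature_1, ...
rules = export_text(shallow)
print(rules)
print("depth:", shallow.get_depth(), "leaves:", shallow.get_n_leaves())
```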

The quality of the predictions thus depends on the tree's depth. The task is to find the depth that gives the best score on unseen data while keeping the gap between the training and test scores small. For now, we can simply do this by training models with different max_depth values and comparing their performance on the train and test data. Later, you'll learn about more reliable and efficient techniques for finding the best parameters, such as cross-validation.
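A minimal sketch of this manual search is shown below. It assumes a synthetic stand-in for the sonar data built with make_classification, so its numbers won't match the ones above; the depth range 1 to 10 is an arbitrary choice. Picking the depth by its test score is a shortcut: once the test set guides the choice, it is no longer truly unseen.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the sonar data (208 samples, 60 features)
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=34)

results = {}  # max_depth -> (train accuracy, test accuracy)
for depth in range(1, 11):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=34).fit(X_train, y_train)
    train_acc, test_acc = clf.score(X_train, y_train), clf.score(X_test, y_test)
    results[depth] = (train_acc, test_acc)
    print(f"max_depth={depth:2d}  train={train_acc:.3f}  test={test_acc:.3f}  gap={train_acc - test_acc:.3f}")

# Depth with the highest test accuracy (the smallest such depth wins in case of a tie)
best_depth = max(results, key=lambda d: results[d][1])
print("Best max_depth on this split:", best_depth)
```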

  • min_samples_split (default: 2, int or float) & min_samples_leaf (default: 1, int or float)

max_depth is not the only way to tell a classifier when it must stop splitting. We can also set the minimum number of samples required to split a node: if a node contains fewer samples than min_samples_split, it won't be split and becomes a leaf.

Another parameter, min_samples_leaf, specifies the minimum number of samples required at a leaf node. In other words, a split point is only considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. Both parameters also accept a float, which is interpreted as a fraction of the number of training samples.
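You can check both constraints on a fitted tree through its low-level tree_ object: children_left is -1 for leaves, and n_node_samples holds the number of training samples that reached each node. The sketch assumes a synthetic stand-in dataset, and the thresholds 20 and 10 are arbitrary values chosen for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the sonar data
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)

settings = {
    "defaults": {},
    "min_samples_split=20": {"min_samples_split": 20},
    "min_samples_leaf=10": {"min_samples_leaf": 10},
}
models = {}
for name, params in settings.items():
    clf = DecisionTreeClassifier(random_state=34, **params).fit(X, y)
    models[name] = clf
    is_leaf = clf.tree_.children_left == -1   # leaves have no children
    sizes = clf.tree_.n_node_samples          # training samples that reached each node
    print(f"{name:<22} depth={clf.get_depth():2d}  leaves={clf.get_n_leaves():3d}  "
          f"smallest leaf={int(sizes[is_leaf].min()):3d}  "
          f"smallest split node={int(sizes[~is_leaf].min()):3d}")
```

Every internal node of the second tree has at least 20 samples, and every leaf of the third tree has at least 10; both constrained trees usually end up with fewer leaves than the default one.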

  • max_features (default: None; int, float, 'sqrt' or 'log2')

Last but not least, max_features determines how many features are considered when looking for the best split at each node. You can pass an integer (that exact number of features), a float (that fraction of all features), 'sqrt' (the square root of the number of features) or 'log2' (the binary logarithm of it); the default None uses all features. When should we specify max_features? It's useful when the dataset has many features and we wish to reduce the training time of our model, since the cost of searching for a split grows with the number of features examined. Note that the features to examine are then drawn at random at each split, so random_state matters even more.
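The resolved number of features is stored in the fitted model's max_features_ attribute. The sketch below, again on an assumed synthetic stand-in with 60 features like the sonar data, shows how each kind of value is translated.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in with 60 features, like the sonar data
X, y = make_classification(n_samples=208, n_features=60, n_informative=10, random_state=34)

resolved = {}
for max_features in [None, 10, 0.5, "sqrt", "log2"]:
    clf = DecisionTreeClassifier(max_features=max_features, random_state=34).fit(X, y)
    # max_features_ is the number of features actually examined at each split
    resolved[max_features] = clf.max_features_
    print(f"max_features={max_features!r}: {clf.max_features_} of {clf.n_features_in_} features per split")
```

With 60 features, None keeps all 60, 0.5 gives 30, 'sqrt' resolves to 7 (the integer part of √60) and 'log2' to 5 (the integer part of log2 60).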

Conclusion

In this topic, we learned how to:

  • Train a decision tree classifier;

  • Change a criterion that measures the quality of a split;

  • Control the tree's growth with max_depth, min_samples_split and min_samples_leaf to avoid underfitting and overfitting;

  • Limit the number of features considered at each split to speed up training.

You can find more information about the DecisionTreeClassifier in the documentation.
