Visualize decision tree in Python with Graphviz

Visualize Decision Tree


How to visualize a decision tree in Python

The decision tree classifier is one of the most widely used supervised learning algorithms. Unlike many other classification algorithms, the decision tree classifier is not a black box after the modeling phase. That means we can visualize the trained decision tree to understand how it will make predictions for the given input features.

In this article, you are going to learn how to visualize a trained decision tree model in Python with Graphviz. Let’s begin with the table of contents.


Table of contents

  • A basic introduction to decision tree classifier
  • Fruit classification with decision tree classifier
  • Why we need to visualize the trained decision tree
  • Understand the visualized decision tree
  • Visualize decision tree in python
    • What is Graphviz
    • Visualize the decision tree online
    • Visualize the decision tree as pdf

Introduction to Decision tree classifier

The decision tree classifier is a widely used classification algorithm because of its advantages over other classification algorithms. These advantages are not about the accuracy of the trained decision tree model; they are about how easy the algorithm is to use and to understand.

Decision tree advantages:

  • Building a decision tree is simple to implement.
  • Decision trees can be used for both classification and regression problems.
  • The cost of predicting with a trained decision tree is logarithmic in the number of observations used to train it.
  • The trained decision tree can be visualized.
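To illustrate the regression point above, here is a minimal sketch using scikit-learn's DecisionTreeRegressor. The four-point dataset is made up purely for illustration and is not part of the fruit example that follows.

```python
from sklearn.tree import DecisionTreeRegressor

# Hypothetical 1-D toy data: the target jumps from 1.0 to 5.0 between x=2 and x=3
X = [[1], [2], [3], [4]]
y = [1.0, 1.0, 5.0, 5.0]

regressor = DecisionTreeRegressor(random_state=0)
regressor.fit(X, y)

# Each leaf predicts the mean target value of the training rows that reach it
predictions = regressor.predict([[1.5], [3.5]])
print(predictions)  # [1. 5.]
```

The same tree structure (split conditions leading to leaves) is used; only the leaf output changes from a class label to a number.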

Now that we know the advantages of the decision tree over other classification algorithms, let’s look at a basic introduction to the decision tree.

If you have gone through our article on how decision tree classifiers work in machine learning, you will already be aware of decision tree keywords like root node, leaf node, information gain, Gini index, tree pruning, etc.

These keywords give you a basic introduction to the decision tree classifier. If you are new to the decision tree classifier, please spend some time on our introductory decision tree articles before you continue reading about how to visualize the decision tree in Python.

The decision tree classifier is a classification model that creates a set of rules from the training dataset. These rules are then used to predict the target class. To get a clear picture of the rules and of why we need to visualize them, let’s build a toy decision tree classifier. We will then use it to understand the need to visualize the trained decision tree.

Fruit classification with decision tree classifier


Fruit classification with decision tree

The decision tree classifier will be trained on apple and orange features; the trained classifier can then predict the fruit label from the fruit features.

The fruit dataset is a dummy dataset. Below are its features and targets.

Weight (grams)   Smooth (range 1 to 10)   Fruit
170              9                        1
175              10                       1
180              8                        1
178              8                        1
182              7                        1
130              3                        0
120              4                        0
130              2                        0
138              5                        0
145              6                        0

The dummy dataset has two features and one target.

  • Weight: the weight of the fruit in grams
  • Smooth: the smoothness of the fruit, in the range of 1 to 10
  • Fruit: the target; 1 means apple and 0 means orange.

Let’s follow the workflow below to model the fruit classifier.

  • Loading the required Python machine learning packages
  • Create and load the data in Pandas dataframe
  • Building the fruit classifier with decision tree algorithm
  • Predicting the fruit type from the trained classifier

Loading the required Python machine learning packages

# Required Python Packages
import pandas as pd
import numpy as np
from sklearn import tree

The required Python machine learning packages for building the fruit classifier are Pandas, NumPy, and Scikit-learn.

  • Pandas: For loading the dataset into a dataframe. The loaded dataframe is then passed as an input parameter when modeling the classifier.
  • NumPy: For creating the dataset and performing numerical calculations.
  • Sklearn: For training the decision tree classifier on the loaded dataset.

Now let’s create the dummy dataset and load it into a pandas dataframe.

Create and load the data in Pandas dataframe

# creating dataset for modeling Apple / Orange classification
fruit_data_set = pd.DataFrame()
fruit_data_set["fruit"] = np.array([1, 1, 1, 1, 1,      # 1 for apple
                                    0, 0, 0, 0, 0])     # 0 for orange
fruit_data_set["weight"] = np.array([170, 175, 180, 178, 182,
                                     130, 120, 130, 138, 145])
fruit_data_set["smooth"] = np.array([9, 10, 8, 8, 7,
                                     3, 4, 2, 5, 6])
  • An empty pandas dataframe is created to hold the fruit dataset.
  • NumPy arrays are created for the target, weight, and smooth columns.
    • The target has two unique values: 1 for apple and 0 for orange.
    • Weight is the weight of the fruit in grams.
    • Smooth is the smoothness of the fruit in the range of 1 to 10.
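As a quick sanity check on the dummy data, the sketch below builds the same dataframe from a dict in one step (an equivalent alternative to the column-by-column construction above) and prints the per-class minimum and maximum of each feature. It shows that the two fruits do not overlap on either feature, which is why a single split will turn out to be enough.

```python
import numpy as np
import pandas as pd

# Same dummy data as above, built from a dict in one step
fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})

# Per-class min/max of each feature shows how far apart the two fruits are
summary = fruit_data_set.groupby("fruit")[["weight", "smooth"]].agg(["min", "max"])
print(summary)
```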

Now, let’s use the loaded dummy dataset to train a decision tree classifier.

Building the fruit classifier with decision tree algorithm

fruit_classifier = tree.DecisionTreeClassifier()
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

print(">>>>> Trained fruit_classifier <<<<<")
print(fruit_classifier)
  • Create the decision tree classifier instance from the imported scikit-learn tree module.
  • Train the decision tree model on the loaded fruit dataset features and target.
  • Print the trained fruit classifier.

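Script Output (from an older scikit-learn version; recent versions print only the non-default parameters, here just DecisionTreeClassifier()):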

>>>>> Trained fruit_classifier <<<<<
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')
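Printing the classifier only shows its settings, not what it learned. A minimal sketch for peeking at the fitted tree, using scikit-learn's get_depth(), get_n_leaves() and the low-level tree_ arrays. Note that random_state=0 is an addition not in the original code, used only so repeated runs give the same tree.

```python
import numpy as np
import pandas as pd
from sklearn import tree

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

# Size of the learned tree
print("depth:", fruit_classifier.get_depth())
print("leaves:", fruit_classifier.get_n_leaves())

# Root node split: feature index (0 = weight, 1 = smooth) and its threshold
root_feature = fruit_classifier.tree_.feature[0]
root_threshold = fruit_classifier.tree_.threshold[0]
print("root split: x[{}] <= {}".format(root_feature, root_threshold))
```

A depth of 1 with two leaves means the whole model is a single if-else on one feature.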

Now let’s use the fruit classifier to predict the fruit type from the fruit features.

Predicting the fruit type from the trained classifier

# fruit data set 1st observation
test_features_1 = [[fruit_data_set["weight"][0], fruit_data_set["smooth"][0]]]
test_features_1_fruit = fruit_classifier.predict(test_features_1)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set["fruit"][0], predicted_fruit=test_features_1_fruit))

# fruit data set 3rd observation
test_features_3 = [[fruit_data_set["weight"][2], fruit_data_set["smooth"][2]]]
test_features_3_fruit = fruit_classifier.predict(test_features_3)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set["fruit"][2], predicted_fruit=test_features_3_fruit))

# fruit data set 8th observation
test_features_8 = [[fruit_data_set["weight"][7], fruit_data_set["smooth"][7]]]
test_features_8_fruit = fruit_classifier.predict(test_features_8)
print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
    act_fruit=fruit_data_set["fruit"][7], predicted_fruit=test_features_8_fruit))

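We took 3 observations from the fruit dataset, used the trained fruit classifier to predict their fruit type, and compared the predictions with the actual fruit type. These observations were also part of the training data, so this only checks that the model fits what it has already seen; a real evaluation needs observations the model was not trained on.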

Script Output: 

Actual fruit type: 1 , Fruit classifier predicted: [1]
Actual fruit type: 1 , Fruit classifier predicted: [1]
Actual fruit type: 0 , Fruit classifier predicted: [0]
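Instead of predicting one observation at a time, you can pass many rows in a single predict call. The sketch below scores the classifier on all training rows with accuracy_score and then predicts two made-up fruits. The new fruit values are hypothetical, and they are passed as a dataframe with the same column names used for training, which avoids the feature-name warning newer scikit-learn versions emit for plain lists.

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.metrics import accuracy_score

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
features = fruit_data_set[["weight", "smooth"]]
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(features, fruit_data_set["fruit"])

# Predict every row in one call (these rows were seen during training)
all_predictions = fruit_classifier.predict(features)
training_accuracy = accuracy_score(fruit_data_set["fruit"], all_predictions)
print("Training accuracy:", training_accuracy)

# Two hypothetical new fruits: a heavy smooth one and a light rough one
new_fruits = pd.DataFrame({"weight": [160, 125], "smooth": [9, 3]})
new_predictions = fruit_classifier.predict(new_fruits)
print(new_predictions)
```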

The fruit classifier trained with the decision tree algorithm correctly predicts the fruit type for the given fruit features. But so far we only know that it predicts the target fruit type in a black-box way; we don’t know what’s happening inside the black box.

To understand what’s happening inside the trained decision tree model and how it predicts the target class for the given features, we need a visual representation of the trained decision tree classifier.
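Before reaching for Graphviz, a quick text-only look at the learned rules is possible with scikit-learn's export_text function. This is a minimal sketch of that alternative; it prints the same split conditions the graph will show, just as indented text.

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.tree import export_text

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

# One line per split condition; leaves show the predicted class
rules = export_text(fruit_classifier, feature_names=["weight", "smooth"])
print(rules)
```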

Why we need to visualize the trained decision tree

To answer the question of why we need to visualize the trained decision tree, I am going to show you the visual representation of the above fruit classifier.

Before I show you the visual representation of the trained decision tree classifier, have a look at the 3 test observations we considered for predicting the target fruit type from the fruit classifier.

Weight   Smooth   Fruit   Classifier Predicted
170      9        1       1
180      8        1       1
130      2        0       0

The below image is the visual representation of the trained fruit classifier.


Decision tree visualization in Python with Graphviz

In the upcoming section, you are going to learn how to visualize the decision tree in Python with Graphviz.

Decision tree visualization explanation

The root node of the trained decision tree splits on the fruit weight (x[0]). This split condition is learned from the provided training dataset.

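If the weight is less than or equal to 157.5, go to the left node; if the weight is greater than 157.5, go to the right node. Both the left and right nodes are leaf nodes, because the decision tree found that a single feature (weight) is enough to classify the fruit type. Smoothness also separates the two fruits perfectly (at 6.5), and scikit-learn may pick either of two equally good splits, so a retrained tree can split on x[1] instead.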
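Where does 157.5 come from? A simplified sketch of the idea: candidate thresholds are midpoints between consecutive sorted feature values, and each candidate is scored by the weighted Gini impurity of the two child nodes. This is a plain-Python illustration of the criterion, not scikit-learn's actual implementation; the helper names gini and best_threshold are hypothetical.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_threshold(values, labels):
    """Return (threshold, weighted_gini) of the best split 'value <= threshold'."""
    unique_values = sorted(set(values))
    best = (None, float("inf"))
    # Try the midpoint between every pair of neighbouring distinct values
    for low, high in zip(unique_values, unique_values[1:]):
        threshold = (low + high) / 2
        left = [lab for v, lab in zip(values, labels) if v <= threshold]
        right = [lab for v, lab in zip(values, labels) if v > threshold]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best[1]:  # keep the first split with the lowest impurity
            best = (threshold, score)
    return best


weights = [170, 175, 180, 178, 182, 130, 120, 130, 138, 145]
fruits = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
print(best_threshold(weights, fruits))  # (157.5, 0.0)
```

157.5 is simply the midpoint between the heaviest orange (145 g) and the lightest apple (170 g), and it produces two pure child nodes (Gini 0).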

The below pseudo-code represents the above graph as simple if-else conditions.

def predict(weight, smooth):
    # the trained tree only splits on weight, so smooth is never used
    if weight <= 157.5:
        return 0  # orange
    else:
        return 1  # apple

Now if you pass the same 3 test observations we used with the trained fruit classifier through these conditions, you will see why and how the trained decision tree predicts the fruit type for the given fruit features.
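The if-else above is hand-written for this particular tree. A minimal sketch of a general version, walking the fitted tree through scikit-learn's tree_ arrays (children_left, children_right, feature, threshold, value) until a leaf is reached. The helper name walk_tree is hypothetical; it should agree with fruit_classifier.predict whichever feature the tree split on.

```python
import numpy as np
import pandas as pd
from sklearn import tree

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])


def walk_tree(classifier, sample):
    """Follow the fitted tree from the root to a leaf for one sample."""
    t = classifier.tree_
    node = 0
    # A node is a leaf when both child ids are equal (both are -1)
    while t.children_left[node] != t.children_right[node]:
        if sample[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    # value[node][0] holds the class distribution at the leaf; take the majority
    return classifier.classes_[np.argmax(t.value[node][0])]


for weight, smooth in [(170, 9), (180, 8), (130, 2)]:
    print(weight, smooth, "->", walk_tree(fruit_classifier, [weight, smooth]))
```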

Visualize decision tree in python with Graphviz

I hope you now see the advantages of visualizing the decision tree. Now let’s move to the key section of this article: visualizing the decision tree in Python with Graphviz.

You can visualize the trained decision tree in python with the help of Graphviz. Below are two ways to visualize the decision tree model.

  • Visualize the decision tree online
  • Visualize the decision tree as pdf

In both cases, you first need to convert the trained decision tree classifier into Graphviz’s DOT format, which is then used for the visualization. So it’s better to know a little about Graphviz before looking into the visualization part.
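If you only need the DOT text in memory (for example, to inspect it or hand it to another tool), scikit-learn’s export_graphviz returns it as a string when out_file=None. A minimal sketch:

```python
import numpy as np
import pandas as pd
from sklearn import tree

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

# With out_file=None the DOT description is returned instead of written to a file
dot_source = tree.export_graphviz(fruit_classifier, out_file=None)
print(dot_source)
```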

What is Graphviz

Graphviz is an open-source graph visualization library. It is widely used in networking applications to visualize the connections between switches, hubs, and different networks. In machine learning, it is used to visualize decision trees and neural networks.

Now let’s look at how to visualize the decision tree with graphviz.

Visualize the decision tree online

To visualize the decision tree online, you first need to convert the trained decision tree (in our case, the fruit classifier) into a file (a txt file works well). You can then use the contents of the converted file to visualize it online.

The below code converts the trained fruit classifier into Graphviz DOT format and saves it into a txt file.

with open("fruit_classifier.txt", "w") as f:
    # export_graphviz writes the DOT description into f (it returns None here)
    tree.export_graphviz(fruit_classifier, out_file=f)

If you have the proper Python machine learning packages set up on your system, running the above code saves fruit_classifier.txt on your local system.

To visualize the decision tree, open the fruit_classifier.txt file, copy its contents, and paste them into the Graphviz web portal. Below is the address of the web portal.
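By default the exported graph labels features as x[0], x[1] and classes by index. A sketch of a more readable export, using export_graphviz’s feature_names, class_names, filled and rounded parameters. The output filename fruit_classifier_labeled.txt is hypothetical, and the class names are listed in the order of classes_ (0 = orange, 1 = apple).

```python
import numpy as np
import pandas as pd
from sklearn import tree

fruit_data_set = pd.DataFrame({
    "fruit": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),  # 1 = apple, 0 = orange
    "weight": np.array([170, 175, 180, 178, 182, 130, 120, 130, 138, 145]),
    "smooth": np.array([9, 10, 8, 8, 7, 3, 4, 2, 5, 6]),
})
fruit_classifier = tree.DecisionTreeClassifier(random_state=0)
fruit_classifier.fit(fruit_data_set[["weight", "smooth"]], fruit_data_set["fruit"])

with open("fruit_classifier_labeled.txt", "w") as f:
    tree.export_graphviz(
        fruit_classifier,
        out_file=f,
        feature_names=["weight", "smooth"],  # replaces x[0], x[1] in the labels
        class_names=["orange", "apple"],     # order follows classes_ = [0, 1]
        filled=True,                         # colour nodes by majority class
        rounded=True,                        # rounded node boxes
    )

with open("fruit_classifier_labeled.txt") as f:
    labeled_dot = f.read()
print(labeled_dot)
```

The resulting text can be pasted into the web portal in exactly the same way as fruit_classifier.txt.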

graphviz web portal address: http://webgraphviz.com

You can see the Graphviz web portal below.

graphviz web portal


Once the Graphviz web portal opens, remove the text already present in the text box, paste the contents of the created txt file, and click the generate-graph button.

For the modeled fruit classifier, we will get the below decision tree visualization.

decision tree visualization with graphviz


Now let’s look at how to visualize the trained decision tree as pdf

Visualize the decision tree as pdf

Visualizing the trained decision tree as a PDF works the same way as above. The only change is that instead of copying and pasting the contents of the converted txt file into the web portal, you convert it into a PDF file.

# converting the trained tree into a DOT file (rendered to PDF in the next step)
with open("fruit_classifier.dot", "w") as f:
    tree.export_graphviz(fruit_classifier, out_file=f)

The above code converts the trained decision tree classifier into Graphviz DOT format and stores it in the fruit_classifier.dot file. To convert the dot file into a PDF file, you can use the below command (it requires the Graphviz software, which provides the dot program, to be installed).

dot -Tpdf fruit_classifier.dot -o fruit_classifier.pdf
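If you prefer to stay inside Python, the same dot command can be run with the standard-library subprocess module. This is a minimal sketch that assumes the Graphviz dot executable may or may not be on your PATH; the helper name dot_to_pdf and the tiny example_tree.dot graph are hypothetical stand-ins for fruit_classifier.dot.

```python
import shutil
import subprocess


def dot_to_pdf(dot_path, pdf_path):
    """Render a DOT file to PDF with the Graphviz `dot` CLI; return None if dot is missing."""
    if shutil.which("dot") is None:
        print("Graphviz 'dot' executable not found on PATH")
        return None
    # Same as running: dot -Tpdf <dot_path> -o <pdf_path>
    subprocess.run(["dot", "-Tpdf", dot_path, "-o", pdf_path], check=True)
    return pdf_path


# A tiny hand-written DOT graph (hypothetical) so the sketch runs on its own
with open("example_tree.dot", "w") as f:
    f.write('digraph Tree {\n'
            '0 [label="weight <= 157.5"] ;\n'
            '1 [label="orange"] ;\n'
            '2 [label="apple"] ;\n'
            '0 -> 1 ;\n'
            '0 -> 2 ;\n'
            '}\n')

result = dot_to_pdf("example_tree.dot", "example_tree.pdf")
print(result)
```

For the article’s tree, you would call dot_to_pdf("fruit_classifier.dot", "fruit_classifier.pdf").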

On macOS, you can preview the created PDF file with the below command.

open -a preview fruit_classifier.pdf

You can get the complete code of this article on our GitHub account.


I hope you like this post. If you have any questions, feel free to comment below. If you want me to write on a particular topic, let me know in the comments below.


16 Responses to “visualize decision tree in python with graphviz”

  • Sophia Yue
    5 years ago

    Hi,
    I enjoy reading your article and I am able to browse the tree online for Iris data. For some reason, there are a couple of typo errors under “What is Graphviz”.
    Below is the excerpt from the Internet:
    Graphviz is one of the visualization libray. The gratness of graphviz is that it’s a open source visualiztion library. Graphiz widely used in networking applicaiton where to visulaze the connection beteen the swiths hub and differnt networks. When it’s comes to machine leanring used for decision tree and newral networks.

    Below is my version for your reference.
    Graphviz is one of the visualization libraries. The greatness of Graphviz is that it’s an open source visualization library. Graphviz widely used in networking application to visualize the connection between the switch hub and different networks. When it comes to machine learning used for decision tree and neural networks.
    Best,
    Sophia

    • Thanks a lot, Sophia Yue,

      It’s surprising to me how those typos crept in. I have corrected all the typos in the article.

      Thanks and happy learning!.

  • deepak kansal
    5 years ago

    But When i want to import graphviz in pycharm it gives error in Source

    • Hi Deepak,

      Could you install graphviz in the same environment where your code is running? Hopefully that will resolve the issue.

      Thanks and happy learning!

  • Hi Saimadhu! I am a new starter of machine learning. Thank you for this helpful article.There is one things I am not sure and hope you can help me clarify! I remember that the training data set and the testing data set should always be different. Why do you use [[fruit_data_set[“weight”][0], fruit_data_set[“smooth”][0]]] to predict test_feature_1, which I assume is already loaded to the classifier. Could you please explain that? And also why there is double brackets outside [[fruit_data_set[“weight”][0], fruit_data_set[“smooth”][0]]]? I would really appreciate your help!

    • Hi Yahui,

      Thanks for your compliment. In the article, we are checking how the built model performs by passing it the features and predicting the target class. The double brackets are the proper syntax for passing a single observation (a single row).

      Thanks and happy learning

  • Thank for work done. It is nice. Great!!! Pls is there any mathematical or statistical step to back on random forest

  • Dear Ffion,
    I can’t see, how below command knows, which data we want to visualize with the model.
    with open(“fruit_classifier.txt”, “w”) as f:
    f = tree.export_graphviz(fruit_classifier, out_file=f)

    We only feed tree.export_graphviz with the name of the model, not with the data.

    • Hi Anna,

      Yes, you are correct; it seems like we haven’t used the data. But all the trained model information is stored in fruit_classifier.txt, and we later use that information to visualize the model.

  • please help when i applied this code it give this type of error:
    print (“Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}”).format(

    AttributeError: ‘NoneType’ object has no attribute ‘format’

    • Hi Alsubari,
      Call .format() on the string inside print(), not on the result of print() (in Python 3, print returns None). Do check the below code.

      print("Actual fruit type: {act_fruit} , Fruit classifier predicted: {predicted_fruit}".format(
      act_fruit=fruit_data_set["fruit"][7], predicted_fruit=test_features_8_fruit))

  • Hi, Could someone please explain what the number in the brackets refers to? (e.g. x[0])

    • Hi Ffion,

      In the article x[0] represents the first feature. In the example the feature is weight.

      • Hi Samadhu,

        Thank you for your response. I understand that the x would represent the feature, however when apply the tree to my code it starts with x[0], then the two options below state x[9]. Would this number refer to this split?

        • Hi Ffion,

          Basically, the x represents the list of features. In our case x[0] represents the first feature likewise other. We can relate this to how the decision tree splits the features.
