Decision Tree Classifier implementation in R

The decision tree classifier is a supervised learning algorithm that can be used for both classification and regression tasks. We explained the building blocks of the decision tree algorithm in our earlier articles. Now we are going to implement a decision tree classifier in R using the caret machine learning package.

To get more out of this article, it is recommended that you first learn about the decision tree algorithm. If you don't have a basic understanding of the decision tree classifier, it's worth spending some time on how the algorithm works.

Why use the Caret Package

To work on big datasets, we can directly use ready-made machine learning packages. The R developer community has built great packages such as caret to make this work easier. These packages are well optimized and handle most edge cases for us, so we only need to call their functions with the right parameters.

Caret Package Installation

The R machine learning package caret (Classification And REgression Training) holds many functions that help build predictive models. It provides tools for data splitting, pre-processing, feature selection, model tuning, and a unified interface to many learning algorithms. It is similar to the scikit-learn (sklearn) library in Python.

To use it, we first need to install it. Open the R console and install it by running install.packages("caret").

The installed caret package gives us direct access to functions for training models with many machine learning algorithms, such as k-NN, SVM, decision trees, and linear regression.

Cars Evaluation Data Set Description

The Car Evaluation data set consists of 7 attributes: 6 feature attributes and 1 target attribute. All the attributes are categorical. We will build a classifier that predicts the class attribute, which is the 7th column.

No. | Attribute | Values
1 | buying | vhigh, high, med, low
2 | maint | vhigh, high, med, low
3 | doors | 2, 3, 4, 5more
4 | persons | 2, 4, more
5 | lug_boot | small, med, big
6 | safety | low, med, high
7 | Car Evaluation (target variable) | unacc, acc, good, vgood

 

The table above lists each attribute and its possible values.

Car Evaluation Problem Statement:

Build a classifier that evaluates the acceptability of a car from its given features.

Decision Tree classifier implementation in R with Caret Package

R Library import

To implement a decision tree in R, we need to load the "caret" and "rpart.plot" packages. As mentioned above, caret helps with the various tasks of our machine learning work. The "rpart.plot" package is used to draw a visual plot of the decision tree.

 

If you get an error such as "there is no package called 'rpart.plot'" while running the code, first install the package with install.packages("rpart.plot").
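A minimal sketch of the import step. The two library() calls are the import itself; the install-if-missing loop before them is our own convenience addition, not part of the original tutorial.

```r
# Install caret and rpart.plot only if they are not installed yet
# (this guard is an assumption added for convenience).
for (pkg in c("caret", "rpart.plot")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
}

library(caret)       # createDataPartition(), trainControl(), train(), confusionMatrix()
library(rpart.plot)  # prp() for plotting rpart trees; also attaches rpart
```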

Data Import

We are going to import and manipulate the data as a data frame. First, download the dataset. You can download the dataset from here. The data values are separated by commas. After downloading the data file, either set your working directory to the folder that contains it or save the file in your current working directory.

You can get the path of your current working directory by running getwd() in the R console. To change the working directory, run setwd(<path of the new working directory>).

To import the data into an R data frame, we can use the read.csv() function, passing the file name and the header parameter, which tells R whether the first row of the file holds column names. If a header row exists, set header = TRUE; otherwise, set header = FALSE.
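A sketch of the import call, assuming the downloaded file is saved as car.data in the working directory (the file name is hypothetical). The file has no header row, so header = FALSE, and R then names the columns V1 to V7, the names used in the rest of this article. stringsAsFactors = TRUE is our addition: since R 4.0.0, read.csv() keeps text columns as character by default, while caret treats a factor outcome as a classification problem and confusionMatrix() expects factors.

```r
# Hypothetical file name: replace "car.data" with the path of your downloaded file.
# header = FALSE: the file has no header row, so the columns are named V1 ... V7.
# stringsAsFactors = TRUE: read the categorical values (including the target V7) as factors.
car_df <- read.csv("car.data", header = FALSE, sep = ",", stringsAsFactors = TRUE)
```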

To check the structure of the data frame, we can call str() on car_df.
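For example, assuming the same hypothetical car.data file (the read.csv() line is repeated so the snippet runs on its own):

```r
car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

# Prints the number of observations and variables, then each column's type and first values
str(car_df)
```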

 

The output of str() shows that the dataset consists of 1728 observations, each with 7 attributes.

To view the first rows of the dataset (six by default), we can use head().
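A short sketch, again assuming the hypothetical car.data file; head() takes an optional second argument for the number of rows.

```r
car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

head(car_df)     # first 6 rows (the default)
head(car_df, 5)  # first 5 rows
```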

All the features are categorical, so normalization of data is not needed.

Data Slicing

Data slicing splits the data into a training set and a test set. The training set is used for building the model, and the test set must be kept separate from model building. Even during standardization, the scaling parameters should be computed from the training set only, never from the test set.

 

The set.seed() function makes our work reproducible. We want readers to learn the concepts by running these snippets, so the results need to be replicable. Partitioning splits the data randomly, but if readers pass the same value to set.seed(), we will both get identical results.

The caret package provides the createDataPartition() function for partitioning the data into training and test sets. We pass it 3 parameters. The "y" parameter takes the variable according to which the data is partitioned. In our case, the target variable is V7, so we pass car_df$V7 (the V7 column of the car_df data frame). When y is a factor, the split is stratified, so each class keeps roughly the same proportion in both sets.

The "p" parameter holds a decimal value between 0 and 1: the fraction of the data that goes into the training set. We use p = 0.7, which splits the data in a 70:30 ratio. The "list" parameter decides whether the result is returned as a list or a matrix; we pass FALSE so that it is not a list. createDataPartition() then returns a matrix, intrain, holding the row indices of the training records.

We use these indices to split the data into training and testing sets. The line training <- car_df[intrain,] puts the selected rows of the data frame into the training data, and the remaining rows are saved in the testing data frame with testing <- car_df[-intrain,].

To check the dimensions of the training and testing data frames, we can call dim() on each of them.
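A minimal sketch of the slicing step, assuming the hypothetical car.data file. The seed value 3033 is arbitrary; any fixed value makes the split reproducible.

```r
library(caret)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
# Stratified 70:30 split on the target V7; returns a one-column matrix of row indices
intrain <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)

training <- car_df[intrain, ]   # rows whose indices are in intrain
testing  <- car_df[-intrain, ]  # all remaining rows

dim(training)  # number of rows and columns in the training set
dim(testing)   # number of rows and columns in the test set
```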

 

Preprocessing & Training

Preprocessing is about correcting problems in the data before building a machine learning model with it. Problems can take many forms, such as missing values or attributes with different ranges.

To check whether our data contains missing values, we can use the anyNA() function. Here, NA means Not Available.
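A one-line check, assuming the hypothetical car.data file. anyNA() scans every column of the data frame and returns a single TRUE or FALSE.

```r
car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

anyNA(car_df)  # TRUE if any value in the data frame is NA
```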

Since anyNA() returns FALSE, the data has no missing values.

Dataset summarized details

To see summarized details of our data, we can use the summary() function. It gives a basic idea of each attribute's values; for categorical (factor) columns, it shows the count of each level.
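For example, assuming the hypothetical car.data file read with factor columns:

```r
car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

summary(car_df)  # for each factor column: number of observations per level
```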

Training the Decision Tree classifier with criterion as information gain

The caret package provides the train() function for training models with various algorithms; we only need to pass different parameter values for different algorithms. Before calling train(), we use the trainControl() function, which controls the computational nuances of train().

We set 3 parameters of trainControl(). The "method" parameter defines the resampling method. It accepts values such as "boot", "boot632", "cv", "repeatedcv", "LOOCV", and "LGOCV". For this tutorial, we use "repeatedcv", i.e., repeated cross-validation.

With repeated cross-validation, the "number" parameter is the number of folds, and the "repeats" parameter is the number of complete sets of folds to compute. We set number = 10 and repeats = 3, i.e., 10-fold cross-validation repeated 3 times. trainControl() returns a list, which we pass to the train() function.

Before training the decision tree classifier, call set.seed() so that the resampling is reproducible.

To train a decision tree classifier, pass "rpart" as the "method" parameter of train(). rpart is a separate package dedicated to decision tree implementation; caret's train() calls it internally, which keeps our work simple.

We pass the target variable V7 through a formula: "V7 ~ ." means that V7 is the target variable and all other attributes are used as predictors. The "trControl" parameter takes the result of trainControl().

You can check the rpart documentation by typing ?rpart. rpart supports different criteria for splitting the nodes of the tree.

To select a specific criterion, we pass a "parms" parameter to train(), containing a list of parameters for rpart. For the splitting criterion, add a "split" element set to either "information" for information gain or "gini" for the gini index. Here, we use information gain as the criterion.
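A sketch of the information gain training step, assuming the hypothetical car.data file; it repeats the loading and slicing steps so it runs on its own. The seed values, the name trctrl, and tuneLength = 10 (how many candidate cp values caret evaluates) are our assumptions, not values given in this article. train() forwards parms to rpart().

```r
library(caret)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

# 70:30 split stratified on the target V7 (arbitrary seed value)
set.seed(3033)
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]
testing  <- car_df[-intrain, ]

# 10-fold cross-validation, repeated 3 times
trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)

set.seed(3333)  # arbitrary seed so the resampling folds are reproducible
dtree_fit <- train(V7 ~ ., data = training, method = "rpart",
                   parms = list(split = "information"),  # passed on to rpart()
                   trControl = trctrl,
                   tuneLength = 10)  # assumed: number of cp values to evaluate

print(dtree_fit)  # resampled accuracy and kappa for each candidate cp
```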

Trained Decision Tree classifier results

We can check the result of train() by printing the dtree_fit variable. The output shows the accuracy metrics for different values of cp, the complexity parameter of the tree.

Plot Decision Tree

We can visualize the decision tree with the prp() function from the rpart.plot package.
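A sketch of the plotting step under the same assumptions (hypothetical car.data file, arbitrary seeds, tuneLength = 10); it repeats the loading, slicing, and training steps so it runs on its own. prp() is applied to dtree_fit$finalModel, the rpart tree that train() refits with the selected cp. box.palette = "Reds" is only an optional styling choice.

```r
library(caret)
library(rpart.plot)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]

trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
set.seed(3333)  # arbitrary seed value
dtree_fit <- train(V7 ~ ., data = training, method = "rpart",
                   parms = list(split = "information"),
                   trControl = trctrl, tuneLength = 10)  # tuneLength is assumed

# Plot the final rpart tree; split labels show the attribute tested at each node
prp(dtree_fit$finalModel, box.palette = "Reds")
```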

The plot shows the structure of the tree, including the order in which attributes are selected for splitting under the information gain criterion.

Prediction

Our final model is trained with cp = 0.01123596, and we are ready to predict classes for the test set with the predict() function. Let's first predict the target variable for the first record of the test set.

For the first record of the testing data, the classifier predicts the class "unacc". Now it's time to predict the target variable for the whole test set.
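A sketch of the prediction step under the same assumptions (hypothetical car.data file, arbitrary seeds, tuneLength = 10), repeating the earlier steps so it runs on its own. The article does not name the call it used to measure accuracy; here we use caret's confusionMatrix(), which reports overall accuracy along with per-class statistics. The name test_pred is our choice.

```r
library(caret)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]
testing  <- car_df[-intrain, ]

trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
set.seed(3333)  # arbitrary seed value
dtree_fit <- train(V7 ~ ., data = training, method = "rpart",
                   parms = list(split = "information"),
                   trControl = trctrl, tuneLength = 10)  # tuneLength is assumed

predict(dtree_fit, newdata = testing[1, ])  # predicted class of the first test record

test_pred <- predict(dtree_fit, newdata = testing)  # predicted classes for the whole test set
confusionMatrix(test_pred, testing$V7)              # accuracy and per-class statistics
```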

With information gain as the criterion, the classifier achieves 83.72% accuracy on the test set.

Training the Decision Tree classifier with criterion as gini index

Next, let's train a decision tree classifier with the gini index as the splitting criterion. Printing the trained model again shows the accuracy metrics for different values of cp, the complexity parameter of the tree.
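A sketch of the gini training step, assuming the hypothetical car.data file, arbitrary seed values, and tuneLength = 10; the only change from the information gain setup is split = "gini". The name dtree_fit_gini is hypothetical, chosen to keep this model apart from the information gain one.

```r
library(caret)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]
testing  <- car_df[-intrain, ]

trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)

set.seed(3333)  # arbitrary seed value
dtree_fit_gini <- train(V7 ~ ., data = training, method = "rpart",
                        parms = list(split = "gini"),  # gini index as the split criterion
                        trControl = trctrl,
                        tuneLength = 10)  # assumed: number of cp values to evaluate

print(dtree_fit_gini)  # resampled accuracy and kappa for each candidate cp
```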

Plot Decision Tree

We can visualize the decision tree with the prp() function from the rpart.plot package.
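A sketch of plotting the gini tree, under the same assumptions (hypothetical car.data file, arbitrary seeds, tuneLength = 10, hypothetical name dtree_fit_gini); the setup is repeated so the snippet runs on its own.

```r
library(caret)
library(rpart.plot)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]

trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
set.seed(3333)  # arbitrary seed value
dtree_fit_gini <- train(V7 ~ ., data = training, method = "rpart",
                        parms = list(split = "gini"),
                        trControl = trctrl, tuneLength = 10)  # tuneLength is assumed

# Plot the final rpart tree chosen for the gini criterion
prp(dtree_fit_gini$finalModel, box.palette = "Reds")
```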

Prediction

This model is also trained with cp = 0.01123596, so we are ready to predict classes for the test set. Now it's time to predict the target variable for the whole test set.
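A sketch of evaluating the gini model on the test set, under the same assumptions (hypothetical car.data file, arbitrary seeds, tuneLength = 10, hypothetical names dtree_fit_gini and test_pred_gini). As with the information gain model, accuracy is read from caret's confusionMatrix(), which is our choice of call.

```r
library(caret)

car_df <- read.csv("car.data", header = FALSE, stringsAsFactors = TRUE)  # hypothetical file name

set.seed(3033)  # arbitrary seed value
intrain  <- createDataPartition(y = car_df$V7, p = 0.7, list = FALSE)
training <- car_df[intrain, ]
testing  <- car_df[-intrain, ]

trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
set.seed(3333)  # arbitrary seed value
dtree_fit_gini <- train(V7 ~ ., data = training, method = "rpart",
                        parms = list(split = "gini"),
                        trControl = trctrl, tuneLength = 10)  # tuneLength is assumed

test_pred_gini <- predict(dtree_fit_gini, newdata = testing)  # classes for the whole test set
confusionMatrix(test_pred_gini, testing$V7)                   # accuracy and per-class statistics
```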

With the gini index as the criterion, the classifier achieves 86.05% accuracy on the test set, so on this split the gini index gives better results than information gain.

I hope you liked this post. If you have any questions, feel free to comment below. If you want me to write about a particular topic, let me know in the comments.
