Classification and Regression Tree (CART) analysis is a very common modeling technique used to make predictions about a response variable (\(Y\)) based upon several explanatory variables, \(X_1, X_2, \ldots, X_p\). The term Classification Tree is used when the response variable is categorical, while Regression Tree is used when the response variable is continuous. CART analysis is popular because it works well for a wide variety of data sets and is easy to interpret. In addition, it does not require model assumptions (i.e., the variables do not need to follow any particular distribution).
Data: In this tutorial we will start by using a classic data set often called Fisher’s Iris data. This data set is one of the standard data sets automatically loaded within R. It includes information on 150 flowers that are from three species of irises (setosa, versicolor, virginica). There are four explanatory variables, representing the length and width of the flower sepal and petal. Our goal is to use the explanatory variables to classify each iris into the proper species group.
library(ggformula)
library(tree)
library(dplyr)
head(iris)
gf_point(Sepal.Length ~ Sepal.Width, data = iris, color = ~ Species) %>%
gf_labs(title = "Figure 1: Scatterplot of Iris Data")
Questions:

Hint: If you are using the ggformula package, you can use gf_vline(xintercept = ___) to draw vertical lines on your scatterplot.
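As a minimal sketch of that hint (assuming a scatterplot with Petal.Width on the x-axis; the cutoff values 0.8 and 1.75 are taken from the tree fit later in this tutorial):

gf_point(Petal.Length ~ Petal.Width, data = iris, color = ~ Species) %>%
  gf_vline(xintercept = c(0.8, 1.75))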
To create a tree diagram using only the Petal.Width variable, we utilize the tree() function in the tree package. As done in Questions 1 - 3, we will create mathematical rules to divide the predictor space into distinct, non-overlapping regions. Then, every observation within a particular region is given the same predicted value.
The technique used to create these classification trees is often called recursive binary splitting: binary because we split the predictor space into two pieces, and recursive because we repeat this process multiple times. In essence, a classification tree is created by repeatedly choosing the single variable and cutoff that best separate the classes, splitting the data into two regions, and then applying the same process within each resulting region until a stopping criterion is met, as sketched below.
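To make one step of this process concrete, the sketch below searches for a single Petal.Width cutoff. It uses the raw misclassification count as the splitting criterion, which is a simplification (tree() uses deviance by default), so the cutoffs it reports are only illustrative.

# Candidate cutoffs for a single binary split on Petal.Width
cutoffs <- seq(0.2, 2.4, by = 0.1)
# For each cutoff, predict the most common species on each side and count the errors
errors <- sapply(cutoffs, function(cp) {
  left  <- iris$Species[iris$Petal.Width <  cp]
  right <- iris$Species[iris$Petal.Width >= cp]
  sum(left  != names(which.max(table(left)))) +
    sum(right != names(which.max(table(right))))
})
# The cutoff(s) that misclassify the fewest points
cutoffs[errors == min(errors)]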
# Create a classification tree with Species as the response and Petal.Width as the explanatory variable
tree1 <- tree(Species ~ Petal.Width, data = iris)
plot(tree1)
text(tree1)
The classification tree consists of nodes that some say make it look like an upside-down tree. Internal nodes have exactly one incoming edge, and at each internal node the data are split based upon a decision (a mathematical rule for splitting the data). All other nodes are called leaves (also known as terminal nodes or decision nodes).
The tree() function reads the iris data and determines the first splitting rule to be at Petal.Width = 0.8. We predict that when Petal.Width < 0.8, the species will be setosa. The species is predicted to be versicolor when 0.8 < Petal.Width < 1.75, and any point with Petal.Width > 1.75 will be classified as virginica.

There are multiple ways to assess the accuracy of decision trees. One of the most straightforward approaches is to simply count the number of points that were misclassified. Here we use the summary() function to calculate that our misclassification rate = 6/150 = 4%.
summary(tree1)
##
## Classification tree:
## tree(formula = Species ~ Petal.Width, data = iris)
## Number of terminal nodes: 4
## Residual mean deviance: 0.2404 = 35.09 / 146
## Misclassification error rate: 0.04 = 6 / 150
Review the graph you created in Question 2 to verify that, no matter which vertical lines are selected, at least 6 points will be misclassified. Notice that there are multiple rules that would give an identical misclassification rate. For example, creating our first split at Petal.Width = 0.6 or Petal.Width = 0.9 would result in the same misclassification rate.
To visualize how to use two explanatory variables to create a classification tree, go to the CART Shiny App. Try to determine the best splitting rules for the iris data. In the app, use the Iris1 data set. Select the splitting rule for the x-axis by using the x slider. To split the data on the y-axis, select the appropriate box and move the corresponding y slider. Note that the accuracy rate and optimized tree can also be shown.
The tree() function can easily be modified to incorporate more than one explanatory variable. Below we use ~ Petal.Length + Petal.Width to identify specific explanatory variables to include in our tree. We can also use ~ . to construct a classification tree that could potentially use all the explanatory variables. Note that the tree() function doesn't necessarily use all the explanatory variables; the algorithm selects only the variables that are useful in the classification process.
tree2 <- tree(Species ~ Petal.Length + Petal.Width, data = iris)
plot(tree2)
text(tree2)
treeall <- tree(Species ~ ., data = iris)
plot(treeall)
text(treeall)
Questions:

4. Using the tree2 model, if Petal.Width = 1.5 and Petal.Length = 5, what would you classify the species to be?
5. Use the summary() function on both the tree2 and treeall models and give their misclassification rates. Compare both classification trees and explain why you may prefer to use the tree2 model instead of the treeall model.

The trees shown above can quickly get very complex. We are often interested in limiting the number of nodes while still getting accurate predictions. One way to measure accuracy is to simply count the misclassifications. Below we create a table of actual versus predicted values to determine our misclassification rate.
# Create a summary table (called a confusion matrix).
# type = "class" returns the predicted class instead of class probabilities.
tree.pred <- predict(tree1, iris, type = "class")
table(tree.pred, iris$Species)
##
## tree.pred setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 49 5
## virginica 0 1 45
Questions:

6. Using the confusion matrix for the tree1 model, how many misclassifications were made? How many virginica species were misclassified as versicolor?
7. Create a confusion matrix for the treeall model. How many versicolor were misclassified as virginica?

The misclassification rate is a straightforward measure of the accuracy of our model predictions. However, with large data sets, other measures, such as the Gini impurity or cross entropy, are typically used when developing classification trees (James et al., 2013).
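For reference, both measures are simple functions of the class proportions within a region. The sketch below computes them for a single region; the Petal.Width boundaries used in the example are the cutoffs from tree1.

# Gini impurity and cross entropy for a vector of class proportions
node_impurity <- function(p) {
  p <- p[p > 0]  # drop empty classes so that log(0) is never evaluated
  c(gini = sum(p * (1 - p)), entropy = -sum(p * log(p)))
}
# Class proportions in the region 0.8 < Petal.Width < 1.75
p <- prop.table(table(iris$Species[iris$Petal.Width > 0.8 & iris$Petal.Width < 1.75]))
node_impurity(p)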
The previous sections describe how to create classification trees. However, the trees we created did not provide much useful information, since we already know the actual species for all 150 plants in our data set. Typically the goal of CART is to make predictions: when we have the explanatory variables for new plants, we want to be able to accurately predict the unknown response.
To test a model’s ability to make accurate predictions with new data, researchers typically fit a model using only a sample of their original data. This reduced data set is often called a training data set. After a model is fit using their training data set, researchers test their model on the remaining data, often called the testing data set or holdout data set. This technique is called cross validation. Cross validation is essential for properly evaluating the performance of a model. There are many types of cross validation techniques based upon how the training data is selected. We demonstrate one example of the cross validation process below.
# We use set.seed() to ensure that the same random sample is selected every time.
# This is only for demonstration purposes, so that everyone gets exactly the same answer with this code.
set.seed(123)
# Randomly sample 100 rows from the iris data set. These 100 rows will be used as our training data set.
train <- sample(1:nrow(iris), 100)
# Create a classification tree using only the training data
tree.iris <- tree(Species ~ ., iris, subset = train)
plot(tree.iris)
text(tree.iris)
# Use the tree.iris model on the testing data, iris[-train, ]
tree.pred <- predict(tree.iris, newdata = iris[-train,], type = "class")
# Create a summary table of our results based on the testing data
with(iris[-train, ], table(tree.pred, Species))
## Species
## tree.pred setosa versicolor virginica
## setosa 14 0 0
## versicolor 0 16 1
## virginica 0 3 16
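From this table, 1 + 3 = 4 of the 50 test flowers were misclassified, a test misclassification rate of 4/50 = 8%. The same rate can be computed directly (a one-line sketch reusing the objects created above):

mean(tree.pred != iris$Species[-train])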
In the above example, we used two thirds of our data (100 rows) for training and one third of our data (50 rows) for testing. There is no exact rule for the percentage to use for the training data set. Many researchers suggest using 80% of your data for training, while others suggest using between 50% to 90% of the data for training. This decision will depend on the size and complexity of your data. In general, more training data will allow the researcher to create better models and more testing data is useful to accurately evaluate those models. It is common to have thousands of rows of data when creating classification trees. Data sets with only a few hundred rows are rarely useful, especially if there are a large number of explanatory variables.
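As a minimal sketch, an 80/20 split of the iris data (using one of the conventions mentioned above; the object names train80 and test20 are arbitrary) could be drawn as follows:

# Randomly select 80% of the rows for training; the remaining rows are held out for testing
train80 <- sample(1:nrow(iris), size = 0.8 * nrow(iris))
test20 <- iris[-train80, ]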
Whenever researchers create models, they must consider trade-offs between cost (loss in model accuracy) and complexity. For any data set, it is possible to construct a series of mathematical rules that will create a model that is 100% accurate when the response variable is provided within the data. However, such a model is typically over fit.
Models are over fit when they are too focused on the current data set. That is, in order to increase the accuracy of their model, researchers may sometimes build complex models that account for random noise in the data. Consequently, an over-fit model produces accurate classifications for the training set, but makes inaccurate predictions for new data, such as the test set.
To create a parsimonious model (a model that carefully selects a relatively small number of the most useful explanatory variables), researchers create rules to stop making additional nodes, known as stopping rules. For example, if a model is too complex (the tree has too many nodes), researchers often want to “prune” the tree. Below we use two functions, cv.tree() and prune.misclass(), to demonstrate a cross-validation method to create stopping rules.
cv.iris <- cv.tree(tree.iris, FUN = prune.misclass)
plot(cv.iris)
In general, we want to select a tree size that corresponds to the smallest overall error when using cross validation. The plot of cv.iris shows the size (the number of terminal nodes) on the x-axis and the misclassification rate on the y-axis. In our example, we see that 3 or 4 terminal nodes correspond to the lowest error rates. In other words, a model with 3 terminal nodes is the simplest model that still has a low misclassification rate. Note that while we used the misclassification rate (FUN = prune.misclass) to make decisions on pruning this tree, other measures of accuracy are often used. In the code below we create a tree where the number of terminal nodes is set to 3.
prune.iris <- prune.misclass(tree.iris, best = 3)
plot(prune.iris)
text(prune.iris)
Important Note The above code uses cross validation methods on the training data to determine the appropriate size of the tree. In other words, the training data is again subsetted into new training and testing data sets, trees are created, and error rates are calculated. This process occurs multiple times, each with new randomly selected training and testing data sets, to give a measure of the expected error rate for each tree size. In our example, this pruning process is based on our original 100 rows of training data. Testing data should never be used in the creation of a model.
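By default, cv.tree() performs 10-fold cross validation on the training data; the number of folds can be set explicitly with the K argument (shown below with its default value):

cv.iris10 <- cv.tree(tree.iris, FUN = prune.misclass, K = 10)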
After we have developed a final model based on the training data, we then use the testing data one time to get a measure of the accuracy of our model. If the testing data is used anywhere in the process of model development, the confusion matrix is not a true measure of our ability to accurately make predictions on new data.
Below we create a confusion matrix to evaluate the accuracy of our final pruned model, prune.iris.
tree.pred <- predict(prune.iris, iris[-train,], type = "class")
with(iris[-train, ], table(tree.pred, Species))
## Species
## tree.pred setosa versicolor virginica
## setosa 14 0 0
## versicolor 0 16 1
## virginica 0 3 16
Important considerations when using CART

The rpart package provides a much faster implementation than the tree package, and the party package creates much better visualizations. If you are conducting an actual research project using CART analysis, please consider one of these packages.
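As a minimal sketch, fitting the same model with rpart looks like this (assuming the rpart package is installed; its defaults for splitting and pruning differ from those of tree(), so the resulting tree may differ slightly):

library(rpart)
# Fit a classification tree with rpart; method = "class" requests classification
rtree <- rpart(Species ~ ., data = iris, method = "class")
plot(rtree)
text(rtree)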