In this tutorial, we’ll learn how to use tree-based models (random forests) to predict the values of both categorical and continuous variables from the values of other variables in R. My breakdown of how this all works is based on the fabulous overview in the Introduction to Statistical Learning text I’ve referenced throughout the class (Chapter 8), so if you want more detail and code, head there!
Let’s get started!
You’ll need the following:
library(tidyverse)
library(randomForest)
At this point, you should understand the important distinction between models we use for inference and models we use for prediction. Unlike inferential models, such as the regression tools we used earlier in the class, which are great for understanding relationships between variables, predictive models are more about predicting an event or outcome. This distinction is important: because we are less interested in understanding the fundamental relationships between variables, predictive models are often more difficult to explain using equations. These models often generate structures so complex that they are difficult for humans to understand (see this week’s video!). What we can understand, however, are the basic algorithms that identify the structure that best predicts the outcome of interest. In this tutorial, we’ll focus on a commonly-used set of predictive models: tree-based models. We’ll work through an example using a tree-based model for regression (prediction of a continuous variable) and for classification (prediction of a categorical variable that takes on only a few discrete values).
So how do tree-based models work? This image is a really helpful overview:
I like this visual overview. It shows, at a high level, how tree-based models work. Basically, the trees are finding the values of each variable that best help to distill the category you’re trying to predict. Instead of the few trees in this example, full models can build hundreds of trees to best predict the outcome of interest. Trees are basically a giant flow chart of yes/no questions. We try to find the tree structure that best predicts the outcome at the ends (the terminal nodes) of the trees (“Low Risk” versus “High Risk” in this example).
In this tutorial, we’re going to focus on one type of tree-based model, random forests. Random forests are a very popular tree-based model that uses bootstrap aggregation, or bagging, which basically builds tons of regression or classification trees, each on a bootstrapped sample of the data, and then forms a final prediction by averaging across all of these trees (or, for classification, taking a majority vote). Crucially, subsets of predictors are randomly selected for each tree in this forest of trees. This reduces correlation between trees in the forest, improving prediction accuracy!1
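To make the bagging idea concrete, here’s a minimal hand-rolled sketch using the built-in iris data and the rpart package (my own illustration, not part of the corn analysis; randomForest does all of this for you, and a true random forest also re-samples predictors at every split rather than once per tree):
library(rpart)
set.seed(1)
n_trees <- 25
trees <- vector("list", n_trees)
for (i in seq_len(n_trees)) {
  boot_rows <- sample(nrow(iris), replace = TRUE)     # bootstrap sample of the rows
  vars <- sample(setdiff(names(iris), "Species"), 2)  # random subset of predictors for this tree
  trees[[i]] <- rpart(reformulate(vars, response = "Species"),
                      data = iris[boot_rows, ], method = "class")
}
# every tree votes on every observation; the majority class is the "forest" prediction
votes <- sapply(trees, function(t) as.character(predict(t, newdata = iris, type = "class")))
majority <- apply(votes, 1, function(v) names(which.max(table(v))))
mean(majority == iris$Species)  # share of observations the little ensemble classifies correctly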
Random forests are most often used for classification problems, that is, building models that best predict which category an observation falls into. In our case, we’re interested in predicting whether or not a particular place in the US is cultivated with our favorite crop ever, corn. This is real data I’m working with for a growing research project predicting how the places where we grow crops will change in a warmer world. This data.frame contains a column called AP that stands for absence/presence. The column takes a value of 1 when corn is present in that pixel and 0 when corn is not present. The other columns stand for the following things:
B10_MEANTEMP_WARM - average temperature in the warmest quarter of the year
B4_TEMP_SEASONALITY - standard deviation of temperature * 100
IRR - whether or not the pixel is irrigated (0 == no, 1 == yes)
SLOPE - slope of terrain
ELEVATION - elevation of terrain
B2_MEAN_DIURNAL_RANGE - mean diurnal temperature range
B8_MEANTEMP_WET - average temperature in the wettest quarter
B18_PPT_WARMQ - average precipitation in the warmest quarter
B12_TOTAL_PPT - total annual precipitation
B15_PPT_SEASONALITY - standard deviation of precipitation * 100
T_CEC_SOIL - cation exchange capacity of the topsoil
T_OC - topsoil organic carbon
S_PH_H2O - subsoil pH
Soil data comes from the HWSD project. Irrigation from the MIRAD project. And topography from the National Elevation Dataset.
Note that all of the columns that start with B*... are based on the list of biovars built by the WorldClim team. If you’re interested in species distribution modeling or cool open-source climate futures data, check them out! There’s even an R package out there (dismo) that helps you play with this cool data.
ap <- readRDS("./data/corn_ap.RDS")
glimpse(ap)
## Rows: 4,961
## Columns: 14
## $ AP <fct> 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, ...
## $ B10_MEANTEMP_WARM <dbl> 22.56833, 21.99825, 24.21244, 22.40566, 19.48...
## $ B4_TEMP_SEASONALITY <dbl> 1078.1600, 916.1887, 836.2860, 1107.4468, 888...
## $ IRR <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, ...
## $ SLOPE <dbl> 0.09171792, 0.36307254, 0.50377464, 0.2550289...
## $ ELEVATION <dbl> 196, 254, 306, 392, 723, 533, 353, 452, 948, ...
## $ B2_MEAN_DIURNAL_RANGE <dbl> 11.220123, 12.530332, 12.551253, 12.100369, 1...
## $ B8_MEANTEMP_WET <dbl> 18.38585, 15.49067, 11.31021, 18.24522, 11.54...
## $ B18_PPT_WARMQ <dbl> 333.7619, 375.3810, 369.4762, 348.3810, 373.8...
## $ B12_TOTAL_PPT <dbl> 1019.2381, 1281.5238, 1487.5714, 960.9524, 13...
## $ B15_PPT_SEASONALITY <dbl> 62.61211, 40.68610, 44.64520, 75.86792, 39.95...
## $ T_CEC_SOIL <dbl> 19, 12, 6, 19, 12, 23, 11, 18, 7, 18, 26, 12,...
## $ T_OC <dbl> 1.68, 1.45, 0.98, 1.68, 1.45, 2.08, 0.82, 1.8...
## $ S_PH_H2O <dbl> 6.8, 5.2, 5.0, 6.8, 5.2, 7.4, 6.2, 6.8, 4.8, ...
Ok, so how do we build a predictive model of where corn is grown using random forests? First, make sure you’ve installed and loaded the randomForest package. Then, remember that we need to split our data into training and testing data.2 We train our algorithm on the training data, and then see how well the model does at predicting the held-out testing data.
set.seed(1) # this ensures you generate the same random row numbers every time you run the code
# hold out 25% of the data
random_rn <- sample(nrow(ap), ceiling(nrow(ap)*.25)) # generate random row numbers
train <- ap[-random_rn,] # remove those random row numbers
test <- ap[random_rn,] # keep those random row numbers
# run rf, it's easy!
rf_ap <- randomForest(AP ~ ., data = train)
The last line of code actually runs the random forests algorithm. It’s easy to implement in R, but note that there are lots of parameters you can set behind the scenes that can affect model performance (things like the number of trees). Check out Introduction to Statistical Learning for more on how to optimize these parameters. For now, we’re going to keep things simple. By running this little line of code, you’ve just built a fairly sophisticated model to predict where corn is grown.
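Just so you can see what a couple of those knobs look like, here’s a hedged sketch (the values below are purely illustrative, not tuned recommendations, and we’ll stick with the defaults for the rest of the tutorial):
rf_tweaked <- randomForest(AP ~ ., data = train,
                           ntree = 1000,  # grow 1,000 trees instead of the default 500
                           mtry = 4)      # try 4 predictors at each split instead of the default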
rf_ap
##
## Call:
## randomForest(formula = AP ~ ., data = train)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 10.54%
## Confusion matrix:
## 0 1 class.error
## 0 1612 217 0.11864407
## 1 175 1716 0.09254363
When you print out the model, it tells you a few useful things. First, how many variables were tried at each split in the tree (yes, this is a parameter you can change) and how many trees were built in the model (Number of trees). It also reports the “out of bag” (OOB) estimate of the error rate, which is 10.54%. Each tree in the model is built on a bootstrap sample, so it has its own out-of-bag sample of data that was not used during construction. The OOB error rate essentially tells us how often the model mispredicts these held-out observations.
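If you’re curious how that OOB error settles down as more trees are added, the fitted object stores the running estimate (a quick aside; plot() on a randomForest object draws these error curves for you):
head(rf_ap$err.rate)  # OOB error (first column) plus per-class error, by number of trees
plot(rf_ap)           # error rates versus number of trees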
Finally, this returns our confusion matrix, which tells us how often the model got things wrong (guessed a 0 when the correct value was 1, or vice versa). We’ll use this confusion matrix on our test data to come up with an estimate of model performance. To do this, we first need to use our new model to predict our test data using the predict() function:
preds <- predict(rf_ap, test, type = "class")
head(preds)
## 14389 11203 18797 421 14956 17496
## 1 0 1 0 1 1
## Levels: 0 1
preds now contains a list of zeros and ones, one for each row in the test data.frame. Here, the predict() function basically predicts the value of AP (absence or presence) based on the values of the climate, soil, and topographic predictors in the test data. By setting type = 'class' we tell the predict() function to generate its best guess of the value of AP (0 or 1). We can also set type = 'prob', which is cool because it will generate a probability of that row being corn! Try it!
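For example, here’s a quick sketch of that type = 'prob' option:
prob_preds <- predict(rf_ap, test, type = "prob")
head(prob_preds)  # one column per class: the share of trees voting 0 (no corn) or 1 (corn)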
So how do we assess predictive performance on the test data?
test$PREDICTIONS <- preds
cm <- table(test$PREDICTIONS, test$AP)
cm
##
## 0 1
## 0 552 61
## 1 62 566
Here’s our confusion matrix on the test data. It shows, for example, that for 552 rows the model correctly guessed zero (no corn). We can use this confusion matrix to compute the prediction accuracy as follows:
accuracy <- (cm[1, 1] + cm[2, 2])/nrow(test) # correct predictions divided by total observations
accuracy
## [1] 0.9008864
This is basically pulling out the number of times we get things right divided by the total number of observations! So our model got things right ~90% of the time. Not bad!
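If you’d rather report the misclassification (error) rate, it’s just the complement of the accuracy, computed from the off-diagonal cells of the same confusion matrix:
error_rate <- (cm[1, 2] + cm[2, 1])/nrow(test) # wrong predictions divided by total observations
error_rate  # same as 1 - accuracy, roughly 10%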
Another cool thing we can pull from this model is called a variable importance plot. This plot tells us how each variable contributes to the prediction of the outcome, so variables with a higher variable importance contribute more to predicting the outcome of interest. CAUTION. This does not mean that they cause the outcome, just that as the predictor values vary, they help predict variation in the outcome.3
imp <- as.data.frame(varImpPlot(rf_ap)) # this automatically creates a figure
imp$varnames <- rownames(imp) # row names to column
imp
## MeanDecreaseGini varnames
## B10_MEANTEMP_WARM 292.30805 B10_MEANTEMP_WARM
## B4_TEMP_SEASONALITY 324.41407 B4_TEMP_SEASONALITY
## IRR 73.57496 IRR
## SLOPE 99.42673 SLOPE
## ELEVATION 104.82330 ELEVATION
## B2_MEAN_DIURNAL_RANGE 136.54736 B2_MEAN_DIURNAL_RANGE
## B8_MEANTEMP_WET 143.27590 B8_MEANTEMP_WET
## B18_PPT_WARMQ 130.68100 B18_PPT_WARMQ
## B12_TOTAL_PPT 166.10387 B12_TOTAL_PPT
## B15_PPT_SEASONALITY 97.68377 B15_PPT_SEASONALITY
## T_CEC_SOIL 97.82149 T_CEC_SOIL
## T_OC 36.13708 T_OC
## S_PH_H2O 154.36127 S_PH_H2O
Here, varImpPlot() computes MeanDecreaseGini, which is the mean decrease in Gini impurity (often just called the Gini index). The scale of this value is irrelevant; only the relative values matter. It’s basically telling us how much, on average, splits on that variable improve the purity of the resulting nodes across all the trees in the forest. I really like this overview of how the Gini index works.
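As an aside, if you just want these numbers without the automatic plot, the importance() function from the randomForest package returns the same matrix:
importance(rf_ap)  # returns the MeanDecreaseGini values without drawing a figure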
We can use ggplot() to create a better visualization of importance:
ggplot(imp, aes(x=reorder(varnames, MeanDecreaseGini), y=MeanDecreaseGini)) +
geom_point() +
geom_segment(aes(x=varnames,xend=varnames,y=0,yend=MeanDecreaseGini)) +
ylab("Mean decrease Gini") +
xlab("") +
coord_flip() +
theme_minimal()
So here, we see that temperature seasonality and temperatures during the warmest quarter strongly predict where corn is grown. This makes sense! But remember, this doesn’t necessarily mean these things have a causal relationship with corn production, just that they help predict where corn is grown.
The language here might be a bit confusing, but random forests for regression just means we’re using random forests to predict a continuous variable (like corn yields). Hey, we have data on corn yields! Let’s load it and RF it! Luckily, most of the code to implement this is the same; we just have to think differently about how we assess error. Instead of checking whether or not we got the classification correct, we’re now essentially looking at residuals, i.e. how far off our predictions are from the actual yield values. We can use our old friend mean squared error to look at this.
corn_yield <- readRDS("./data/corn.RDS")
# for RF to work, we need to remove the weird identifier variables in the data.frame that don't go in the RF regression
corn_yield <- corn_yield %>% select(-c(GEOID, STATE_NAME, COUNTY_NAME))
set.seed(1)
random_rn <- sample(nrow(corn_yield), ceiling(nrow(corn_yield)*.25))
train <- corn_yield[-random_rn,]
test <- corn_yield[random_rn,] # keep those random row numbers
rf_ap <- randomForest(CORN_YIELD ~ ., data = train)
We can use the same predict() function, but now to predict CORN_YIELD rather than the binary indicator we used above. Now, our confusion matrix approach to assessing error doesn’t really make sense any more. Instead, we need to look at how far off our predicted yields are from actual yields.
preds <- predict(rf_ap, test, type = "response")
test$PREDICTIONS <- preds
head(test %>% select(PREDICTIONS, CORN_YIELD))
## PREDICTIONS CORN_YIELD
## 1107 208.4460 211.9
## 9093 131.0925 146.4
## 5505 130.2069 142.0
## 2628 151.5031 147.0
## 5768 121.6851 141.2
## 1758 105.3472 42.3
We can do this using mean squared error! Remember, this basically gives us a sense of how big our residuals are, or how far off our model is from the real values:
MSE <- mean((test$CORN_YIELD - test$PREDICTIONS)^2)
print(MSE)
## [1] 494.1617
Our goal here is to minimize the mean squared error, so as we run different models, we try to reduce how far off our modeled values are from actual values.
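As a quick aside, taking the square root of the MSE (the RMSE) puts the error back into the same units as CORN_YIELD, which can be easier to interpret:
RMSE <- sqrt(MSE)
RMSE  # roughly 22, in the same units as CORN_YIELD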
Now, let’s use variable importance plots to look at the variables most predictive of yields (DRUM ROLL). Note that we first need to rebuild imp from the new regression model, since the version we made above only contains the classification model’s MeanDecreaseGini values:
imp <- as.data.frame(varImpPlot(rf_ap)) # recompute importance for the regression model
imp$varnames <- rownames(imp) # row names to column
ggplot(imp, aes(x=reorder(varnames, IncNodePurity), y=IncNodePurity)) +
geom_point() +
geom_segment(aes(x=varnames,xend=varnames,y=0,yend=IncNodePurity)) +
ylab("Increase in node purity") +
xlab("") +
coord_flip() +
theme_minimal()
Note that we’re assessing variable importance with a different metric here, the increase in node purity. This is analogous to the Gini-based metric above, and is calculated using the reduction in the sum of squared errors each time a variable is chosen for a split. What’s cool here is that we see that YEAR strongly predicts yield changes, which means there’s likely something changing through time we haven’t included in our model that affects yield (technology? markets? prices?).
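If you want to poke at that YEAR signal, one option (a hedged suggestion, not something we do in this tutorial) is partialPlot() from the randomForest package, which sketches how the predicted yield changes across the values of a single variable while averaging over the others:
partialPlot(rf_ap, pred.data = as.data.frame(train), x.var = "YEAR")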
Love love love this overview with cool visualizations of how random forests work.↩
Technically you want to split into three chunks: one for training, one for validation of training models, and a final independent dataset you only use for testing your final model.↩
You have to be really careful interpreting these plots. Variable importance becomes really complex with highly collinear variables.↩