caret
You may need to install caret with install.packages("caret") before running the next block.
library(caret)
Predictive modeling, as the name implies, is mainly concerned with making good predictions without worrying about making inferences about how a population works (as in causal analysis). Remember that “correlation does not imply causation”, but correlation can help us make useful predictions.
caret: Classification and Regression Training
The caret package provides a uniform interface for fitting 237 different models.
We’ll use the built-in dataset Sacramento, containing data on 932 home sales in Sacramento, CA over a five-day period.
data(Sacramento)
str(Sacramento)
## 'data.frame': 932 obs. of 9 variables:
## $ city : Factor w/ 37 levels "ANTELOPE","AUBURN",..: 34 34 34 34 34 34 34 34 29 31 ...
## $ zip : Factor w/ 68 levels "z95603","z95608",..: 64 52 44 44 53 65 66 49 24 25 ...
## $ beds : int 2 3 2 2 2 3 3 3 2 3 ...
## $ baths : num 1 1 1 1 1 1 2 1 2 2 ...
## $ sqft : int 836 1167 796 852 797 1122 1104 1177 941 1146 ...
## $ type : Factor w/ 3 levels "Condo","Multi_Family",..: 3 3 3 3 3 1 3 3 1 3 ...
## $ price : int 59222 68212 68880 69307 81900 89921 90895 91002 94905 98937 ...
## $ latitude : num 38.6 38.5 38.6 38.6 38.5 ...
## $ longitude: num -121 -121 -121 -121 -121 ...
First we’ll split our dataset into two parts: a training dataset we’ll use to fit our models, and a testing dataset we’ll set aside for comparison after the models are fit. Holding out a testing set helps us avoid overfitting to the training set, because we can check how each model performs on unseen data.
set.seed(12345)
train.select <- createDataPartition(Sacramento$type, p = .8, list = FALSE)
training <- Sacramento[ train.select,]
testing <- Sacramento[-train.select,]
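Because createDataPartition samples within each level of type, both pieces should contain roughly the same mix of home types. Here's a quick, optional check (output not shown):
# Compare the proportion of each home type in the two pieces
round(prop.table(table(training$type)), 3)
round(prop.table(table(testing$type)), 3)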
Many of the more complicated models we can fit with caret need to determine optimal settings for various “tuning” parameters. We can use Repeated k-fold Cross Validation (among other methods) to determine the best values for the tuning parameters within default ranges. In practice you may want to supply your own grid of possible tuning parameter values; a sketch of that appears after the trainControl call below, and the caret documentation on model training and tuning covers the details.
fitControl <- trainControl(## 5-fold Cross Validation
method = "repeatedcv",
number = 5,
## repeated ten times
repeats = 10)
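If you'd rather control the candidate values yourself, you can build a grid with expand.grid and pass it to train (introduced below) via the tuneGrid argument. Here's a minimal sketch for the svmLinear model used later, which has a single tuning parameter C; check getModelInfo("svmLinear") to confirm the parameter names for your caret version:
# Illustrative only: a hand-picked grid of cost values for the svmLinear model
svm.grid <- expand.grid(C = c(0.25, 0.5, 1, 2))
# Passing it to train via tuneGrid replaces the default parameter search:
# train(price ~ .-zip-city, data = training, method = "svmLinear",
#       trControl = fitControl, tuneGrid = svm.grid)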
The train function is used to fit models. The full list of models is available on caret's available models page. You can get more information about a model and its tuning parameters with getModelInfo(<model name>).
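For instance, here's a quick way to see which tuning parameters the random forest method below expects (output not shown; the exact parameters can vary with your caret version):
# Tuning parameters defined for the "ranger" method
getModelInfo("ranger")$ranger$parameters
# modelLookup("ranger") returns a compact summary of the same information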
We’ll fit several example models, all attempting to predict home price from the other variables except zip code and city (these factors have many unique values and would complicate the models).
#Ordinary Least Squares
set.seed(8947) # ensures paired resampling later
lmfit <- train(price ~ .-zip-city, data = training,
method = "lm",
trControl = fitControl)
#Robust Linear Model
set.seed(8947) # ensures paired resampling later
robustfit <- train(price ~ .-zip-city, data = training,
method = "rlm",
trControl = fitControl,
verbose = FALSE)
#Random Forests
set.seed(8947) # ensures paired resampling later
rffit <- train(price ~ .-zip-city, data = training,
method = "ranger",
trControl = fitControl,
verbose = FALSE)
#XGBoost (gradient-boosted trees; like Random Forests, an ensemble of decision trees) - this is the slowest model!
set.seed(8947) # ensures paired resampling later
xgbfit <- train(price ~ .-zip-city, data = training,
method = "xgbTree",
trControl = fitControl,
verbose = FALSE)
#Support Vector Machine with Linear Kernel
set.seed(8947) # ensures paired resampling later
svmfit <- train(price ~ .-zip-city, data = training,
method = "svmLinear",
trControl = fitControl,
verbose = FALSE)
Notes
caret is just a wrapper for fitting models - it does not itself include functions to fit many of these models. When fitting a model you may see the following message:
1 package is needed for this model and is not installed. (<package-name>). Would you like to try to install it now?
1: yes
2: no
Press 1 and hit enter to install the package and fit the model. The packages each model needs are listed in caret's available models list.
We’ve intentionally made sure a few things are consistent across our models to make comparisons easier. Every model predicts the same outcome, price. To ensure we can compare between models with resampling, every model uses the same trControl = fitControl setting, and we set the same seed immediately before each train call with set.seed(8947); the number 8947 is unimportant, it just needs to be consistent.
The resamples function compares the models across the resampled datasets generated while fitting them - here, the repeated cross-validation folds (with other trainControl methods these can instead be datasets simulated by sampling from the training set with replacement, which you may know as “bootstrapping”).
caret gives us three different indices to compare these models: MAE (mean absolute error), RMSE (root mean squared error), and Rsquared.
These track how well the model fits the data in different ways. Without getting into the details of how they're calculated, we'll use the rules of thumb that lower MAE and RMSE values indicate a better fit, while Rsquared values closer to 1 indicate a better fit.
Note: caret provides different metrics (Kappa and accuracy) for classification (i.e. categorical outcomes) tasks.
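To make those rules of thumb concrete, here's a tiny sketch using made-up observed and predicted prices (not our fitted models):
# Illustrative only: toy observed and predicted values
obs  <- c(100000, 150000, 200000)
pred <- c(110000, 140000, 215000)
mean(abs(pred - obs))        # MAE: average size of the errors
sqrt(mean((pred - obs)^2))   # RMSE: similar, but penalizes large errors more
cor(obs, pred)^2             # Rsquared: squared correlation; closer to 1 is better
postResample(pred = pred, obs = obs)  # caret reports all three at once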
results <- resamples(list("OLS"=lmfit,"Random.Forest"=rffit,
"Robust.LM"=robustfit,"SVM"=svmfit,
"xgbTree"=xgbfit))
summary(results)
##
## Call:
## summary.resamples(object = results)
##
## Models: OLS, Random.Forest, Robust.LM, SVM, xgbTree
## Number of resamples: 50
##
## MAE
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## OLS 52269.18 55485.30 58818.95 58441.30 60893.86 65450.93 0
## Random.Forest 49279.82 53766.83 55544.30 55867.46 57880.75 63629.87 0
## Robust.LM 51392.84 55473.35 58250.71 58233.69 60540.08 65537.82 0
## SVM 51448.00 55099.35 58216.62 58035.29 60747.92 65054.30 0
## xgbTree 48263.44 53865.97 55693.54 56146.84 58456.72 63451.97 0
##
## RMSE
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## OLS 65659.34 73381.47 79659.30 80078.43 84591.39 97446.75 0
## Random.Forest 64868.09 73644.40 77216.12 77427.25 81268.66 92233.42 0
## Robust.LM 64088.97 72812.25 79191.32 79970.74 84451.94 97072.50 0
## SVM 64760.41 73762.25 79001.49 80290.13 84719.63 97588.28 0
## xgbTree 62628.90 73273.09 77381.98 77926.27 82772.19 92879.58 0
##
## Rsquared
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## OLS 0.4456435 0.6152435 0.6523681 0.6394884 0.6729048 0.7359156
## Random.Forest 0.5426827 0.6475436 0.6695347 0.6635814 0.6897644 0.7462610
## Robust.LM 0.4513706 0.6185191 0.6554366 0.6416072 0.6758294 0.7365679
## SVM 0.4508157 0.6173829 0.6555446 0.6416658 0.6755272 0.7400575
## xgbTree 0.5183405 0.6376842 0.6667406 0.6600252 0.6950497 0.7487522
## NA's
## OLS 0
## Random.Forest 0
## Robust.LM 0
## SVM 0
## xgbTree 0
We can also present these results in graphical form:
bwplot(results,scales=list(relation="free"))
Remember that these results are across a number of resamples, hence the boxplots and not a single value per model!
The random forest and xgbTree models seem to be doing well here, but neither clearly outperforms the other.
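One way to dig into that question (a sketch using caret's diff method for resamples objects) is to look at the paired differences in these metrics between each pair of models:
# Paired differences between models across the resamples, with simple
# significance tests (output not shown)
model.diffs <- diff(results)
summary(model.diffs)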
Let’s revisit our testing data. We can use our models to generate predictions with predict, then compare their performance with postResample.
lm.test <- predict(lmfit,testing)
robust.test <- predict(robustfit,testing)
rf.test <- predict(rffit,testing)
xgb.test <- predict(xgbfit,testing)
svm.test <- predict(svmfit,testing)
train.results <- rbind(
"LM"=postResample(pred=lm.test,obs=testing$price),
"Robust"=postResample(pred=robust.test,obs=testing$price),
"Random Forest"=postResample(pred=rf.test,obs=testing$price),
"SVM"=postResample(pred=svm.test,obs=testing$price),
"xgbTree"=postResample(pred=xgb.test,obs=testing$price)
)
print(train.results)
## RMSE Rsquared MAE
## LM 75489.91 0.6212181 53912.04
## Robust 74880.00 0.6205356 52911.52
## Random Forest 72432.66 0.6525539 51233.31
## SVM 75551.70 0.6187436 52421.61
## xgbTree 72000.05 0.6566042 49903.68
Which model seems to do best?
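Whichever model you settle on, you can use it to predict prices for new homes with predict. As a closing sketch, here's a hypothetical example built by tweaking one row of the testing data (the values are made up):
# Illustrative only: a hypothetical 2000 sq ft version of one testing-set home
new.home <- testing[1, ]
new.home$sqft <- 2000
predict(xgbfit, newdata = new.home)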