Machine Learning with R: An Irresponsibly Fast Tutorial

Machine learning without the hard stuff


As I said in Becoming a data hacker, R is an awesome programming language for data analysts, especially for people just getting started. In this post, I will give you a super quick, very practical, theory-free, hands-on intro to writing a simple classification model in R, using the caret package. If you want to skip the tutorial, you can find the R code here. Quick note: if the code examples look weird for you on mobile, give it a try on a desktop (you can’t do the tutorial on your phone, anyway!).

The caret package

One of the biggest barriers to learning for budding data scientists is that there are so many different R packages for machine learning. Each package has different functions for training the model, different functions for getting predictions out of the model and different parameters in those functions. So in the past, trying out a new algorithm was often a huge ordeal. The caret package solves this problem in an elegant and easy-to-use way. Caret contains wrapper functions that allow you to use the exact same functions for training and predicting with dozens of different algorithms. On top of that, it includes sophisticated built-in methods for evaluating the effectiveness of the predictions you get from the model. I recommend that you do all of your machine-learning work in caret, at least as long as the algorithm you need is supported. There's a nice little intro paper to caret here.

The Titanic dataset

Most of you have heard of a movie called Titanic. What you may not know is that the movie is based on a real event, and Leonardo DiCaprio was not actually there. The folks at Kaggle put together a dataset containing data on who survived and who died on the Titanic. The challenge is to build a model that can look at characteristics of an individual who was on the Titanic and predict the likelihood that they would have survived. There are several useful variables that they include in the dataset for each person:

  • pclass: passenger class (1st, 2nd, or 3rd)
  • sex
  • age
  • sibsp: number of Siblings/Spouses Aboard
  • parch: number of Parents/Children Aboard
  • fare: how much the passenger paid
  • embarked: where they got on the boat (C = Cherbourg; Q = Queenstown; S = Southampton)

So what is a classification model anyway?

For our purposes, machine learning is just using a computer to “learn” from data. What do I mean by “learn”? Well, there are two main types of learning:

  • supervised learning: Think of this as pattern recognition. You give the algorithm a collection of labeled examples (a training set), and the algorithm then attempts to predict labels for new data points. The Titanic Kaggle challenge is an example of supervised learning, in particular classification.
  • unsupervised learning: Unsupervised learning occurs when there is no training set. A common type of unsupervised learning is clustering, where the computer automatically groups a bunch of data points into different “clusters” based on the data.

Installing R and RStudio

In order to follow this tutorial, you will need to have R set up on your computer. Here's a link to a download page: Inside R Download Page. I also recommend RStudio, which provides a simple interface for writing and executing R code: download it here. Both R and RStudio are totally free and easy to install.

Installing the required R packages

Go ahead and open up RStudio (or just R, if you don't want to use RStudio). For this tutorial, you need to install the caret package and the randomForest package (you only need to do this part once, even if you repeat the tutorial later).
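The install step can be sketched as follows. The `requireNamespace` guard is my own addition (not necessarily how the original code did it), so that re-running the script does not reinstall anything:

```r
# Install caret and randomForest only if they are not already installed.
# install.packages() downloads from CRAN, so this needs an internet connection.
for (pkg in c("caret", "randomForest")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
}
```

If R asks you to choose a CRAN mirror, any mirror near you should work.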

Loading the required R packages

Now we have to load the packages into the working environment (unlike installing the packages, this step has to be done every time you restart your R session).
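A minimal sketch of the loading step, assuming both packages were installed as described above:

```r
# Attach the packages for this R session
library(caret)
library(randomForest)
```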

Loading in the data

Go to the Kaggle download page to find the dataset. Download train.csv and test.csv, and be sure to save them to a place you can remember (I recommend a folder on your desktop called “Titanic”). You might need to sign up for Kaggle first (you should be using Kaggle anyway!).

To load in the data, you first set the R working directory to the place where you downloaded the data.

For example, I downloaded mine to a directory on my Desktop called Titanic, so I set the working directory to that folder.
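Here is a sketch of that step. The path below is hypothetical; it assumes a folder named Titanic on the Desktop of your home directory, so adjust it to wherever you saved the files:

```r
# Hypothetical path: change it to the folder that holds train.csv and test.csv
titanic_dir <- "~/Desktop/Titanic"
setwd(titanic_dir)

# Confirm that the working directory changed
getwd()
```

Note that `setwd` stops with an error if the folder does not exist, which is a handy early check that the path is right.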

Now, in order to load the data, we will use the read.table function
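A minimal sketch of the call, assuming train.csv sits in the working directory:

```r
# Read the comma-separated file, using its first row as the column names
trainSet <- read.table("train.csv", sep = ",", header = TRUE)
```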

This command reads in the file “train.csv”, using the delimiter “,”, including the header row as the column names, and assigns it to the R object trainSet.

Let's read in the testSet also:
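The same sketch for the test file, again assuming test.csv is in the working directory:

```r
# Same arguments as for the training set, different file
testSet <- read.table("test.csv", sep = ",", header = TRUE)
```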

Now, just for fun, let's take a look at the first few rows of the training set:
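One way to do that, assuming train.csv is in the working directory (the file is re-read here only so the snippet runs on its own):

```r
trainSet <- read.table("train.csv", sep = ",", header = TRUE)

# Show the first six rows
head(trainSet)
```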

You'll see that each row has a column “Survived,” which is 1 if the person survived and 0 if they didn't. Now, let's compare the training set to the test set:
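A sketch of that comparison, assuming both files are in the working directory. The `setdiff` line is my own addition; it lists the columns that appear in the training set but not in the test set:

```r
trainSet <- read.table("train.csv", sep = ",", header = TRUE)
testSet  <- read.table("test.csv", sep = ",", header = TRUE)

# First few rows of the test set
head(testSet)

# Columns present in the training set but missing from the test set
missing_from_test <- setdiff(names(trainSet), names(testSet))
missing_from_test
```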

The big difference between the training set and the test set is that the training set is labeled, but the test set is unlabeled. On Kaggle, your job is to make predictions on the unlabeled test set, and Kaggle scores you based on the percentage of passengers you correctly label.

Testing for useful variables

The single most important factor in being able to build an effective model is not picking the best algorithm, or using the most advanced software package, or understanding the computational complexity of the singular value decomposition. Most of machine learning is really about picking the best features to use in the model. In machine learning, a “feature” is really just a variable or some sort of combination of variables (like the sum or product of two variables).

So in a classification model like the Titanic challenge, how do we pick the most useful variables to use? The most straightforward way (but by no means the only way) is to use crosstabs and conditional box plots.

Crosstabs for categorical variables

Crosstabs show the interactions between two variables in a very easy to read way. We want to know which variables are the best predictors for “Survived,” so we will look at the crosstabs between “Survived” and each other variable. In R, we use the table function:
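A minimal sketch for the “Pclass” variable, assuming train.csv is in the working directory:

```r
trainSet <- read.table("train.csv", sep = ",", header = TRUE)

# Rows are Survived (0 or 1), columns are Pclass (1, 2 or 3)
table(trainSet[, c("Survived", "Pclass")])
```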

Looking at this crosstab, we can see that “Pclass” could be a useful predictor of “Survived.” Why? The first column of the crosstab shows that of the passengers in Class 1, 136 survived and 80 died (i.e. 63% of first class passengers survived). On the other hand, in Class 2, 87 survived and 97 died (i.e. only 47% of second class passengers survived). Finally, in Class 3, 119 survived and 372 died (i.e. only 24% of third class passengers survived). Damn, that's messed up.
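If you'd rather not do the percentages in your head, here is a small sketch that recomputes them from the counts quoted above. The matrix is typed in by hand, so it runs without the data file; `counts` and `rates` are just illustrative names:

```r
# Counts from the crosstab, one column per passenger class
counts <- matrix(
  c(80, 136,    # Class 1: died, survived
    97, 87,     # Class 2: died, survived
    372, 119),  # Class 3: died, survived
  nrow = 2,
  dimnames = list(Survived = c("0", "1"), Pclass = c("1", "2", "3"))
)

# Proportion within each class (column) that died or survived
rates <- prop.table(counts, margin = 2)

# Survival percentage per class, rounded to whole percent
round(100 * rates["1", ])
```

With the real data loaded, `prop.table(table(trainSet[, c("Survived", "Pclass")]), margin = 2)` should give the same proportions.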

We definitely want to use Pclass in our model, because it has strong predictive value for whether someone survived or not. Now, you can repeat this process for the other categorical variables in the dataset, and decide which variables you want to include (I'll show you which ones I picked later in the post).

Plots for continuous variables

Plots are often a better way to identify useful continuous variables than crosstabs are (this is mostly because crosstabs aren't so natural for numerical variables). We will use “conditional” box plots to compare the distribution of each continuous variable, conditioned on whether the passengers survived or not ('Survived' = 1 or 'Survived' = 0). You may need to install the *fields* package first, just like you installed *caret* and *randomForest*.
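Here is a sketch of these plots. To keep it dependency-free I am using base R's `boxplot` as a stand-in for the plotting function from *fields*, so this version does not need that package. It assumes train.csv is in the working directory:

```r
trainSet <- read.table("train.csv", sep = ",", header = TRUE)

# Age on the y-axis, one box per value of Survived on the x-axis
boxplot(Age ~ Survived, data = trainSet, xlab = "Survived", ylab = "Age")
n_missing_age <- sum(is.na(trainSet$Age))    # how many ages are NA
n_missing_age

# The same plot and missing-value count for Fare
boxplot(Fare ~ Survived, data = trainSet, xlab = "Survived", ylab = "Fare")
n_missing_fare <- sum(is.na(trainSet$Fare))
n_missing_fare
```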

The box plots of Age for those who survived and those who died are nearly the same. That means that Age probably did not have a large effect on whether one survived or not. The y-axis is Age and the x-axis is Survived (Survived = 1 if the person survived, 0 if not).

[Figure: box plot of Age (y-axis) by Survived (x-axis)]

Also, there are lots of NA’s. Let’s exclude the variable Age, because it probably doesn’t have a big impact on Survived, and also because the NA’s make it a little tricky to work with.

As you can see, the boxplots for Fares are much different for those who survived and those who died. Again, the y-axis is Fare and the x-axis is Survived (Survived = 1 if the person survived, 0 if not).

[Figure: box plot of Fare (y-axis) by Survived (x-axis)]

Also, there are no NA’s for Fare. Let’s include this variable.

Training a model

Training the model uses a pretty simple command in caret, but it's important to understand each piece of the syntax. First, we have to convert Survived to a Factor data type, so that caret builds a classification instead of a regression model. Then, we use the train command to train the model (go figure!). You may be asking what a random forest algorithm is. You can think of it as training a bunch of different decision trees and having them vote (remember, this is an irresponsibly fast tutorial). Random forests work pretty well in *lots* of different situations, so I often try them first.
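A sketch of the training step, under a few assumptions: train.csv is in the working directory, the seed is arbitrary, and the predictor list is illustrative. The text settles on including Pclass and Fare and excluding Age; the other variables in the formula are my guess, not a confirmed choice.

```r
library(caret)
library(randomForest)

trainSet <- read.table("train.csv", sep = ",", header = TRUE)

# A factor outcome tells caret to build a classifier, not a regression
trainSet$Survived <- factor(trainSet$Survived)

# Arbitrary seed so the random folds and the forest are reproducible
set.seed(42)

# Assumed predictor set (only Pclass and Fare are confirmed by the text)
model <- train(
  Survived ~ Pclass + Sex + SibSp + Embarked + Parch + Fare,
  data = trainSet,
  method = "rf",                                        # random forest
  trControl = trainControl(method = "cv", number = 5)   # 5-fold cross-validation
)
```

Typing `model` at the console afterwards prints the cross-validation results discussed in the next section.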

Evaluating the model

For the purposes of this tutorial, we will use cross-validation scores to evaluate our model. Note: in real life (i.e. not Kaggle), most data scientists also split the training set further into a training set and a validation set, but that is for another post.

What is cross-validation?

Cross-validation is a way to evaluate the performance of a model without needing any other data than the training data. It sounds complicated, but it's actually a pretty simple trick. Typically, you randomly split the training data into 5 equally sized pieces called “folds” (so each piece of the data contains 20% of the training data). Then, you train the model on 4/5 of the data, and check its accuracy on the 1/5 of the data you left out. You then repeat this process so that each fold is left out exactly once. At the end, you average the percentage accuracy across the five different splits of the data to get an average accuracy. Caret does this for you, and you can see the scores by printing the model object.
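To make the fold mechanics concrete, here is a hand-rolled sketch of 5-fold cross-validation in base R. It is not what caret runs internally, and everything in it is a stand-in: the built-in `mtcars` data replaces the Titanic data (with `am` as a 0/1 label), and a one-variable logistic regression replaces the random forest.

```r
set.seed(1)                      # arbitrary seed for a reproducible split
toy <- mtcars[, c("am", "wt")]   # toy stand-in data: am is a 0/1 label
k <- 5

# Shuffle the fold labels 1..k so each row lands in exactly one fold
fold <- sample(rep(seq_len(k), length.out = nrow(toy)))

fold_accuracy <- sapply(seq_len(k), function(i) {
  train_part <- toy[fold != i, ]   # the other k - 1 folds
  held_out   <- toy[fold == i, ]   # the fold that is left out
  fit  <- glm(am ~ wt, data = train_part, family = binomial)
  prob <- predict(fit, newdata = held_out, type = "response")
  # Share of held-out rows whose predicted label matches the true label
  mean((prob > 0.5) == (held_out$am == 1))
})

# Average the per-fold accuracies to get the cross-validation score
cv_accuracy <- mean(fold_accuracy)
cv_accuracy
```

On such a tiny dataset `glm` may print warnings for some folds; they do not stop the loop.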

There are a few things to look at in the model output. The first thing to notice is where it says “The final value used for the model was mtry = 5.” The value “mtry” is a hyperparameter of the random forest model that sets how many variables are randomly sampled as candidates at each split in a tree. The table shows different values of mtry along with their corresponding average accuracies (and a couple other metrics) under cross-validation. Caret automatically picks the value of the hyperparameter “mtry” that was the most accurate under cross-validation. This approach is called using a “tuning grid” or a “grid search.”

As you can see, with mtry = 5, the average accuracy was 0.8170964, or about 82 percent. As long as the training set isn't too fundamentally different from the test set, we should expect our accuracy on the test set to be around 82 percent as well.

Making predictions on the test set

Using caret, it is easy to make predictions on the test set to upload to Kaggle. You just have to call the predict method on the model object you trained. Let's make the predictions on the test set and add them as a new column.

Uh, oh! There is an error here! When you get this type of error in R, it means that you are trying to assign a vector of one length to a vector of a different length, so the two vectors don't line up. So how do we fix this problem?
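You can reproduce this kind of error without any model at all. The sketch below uses a made-up three-row data frame and tries to attach only two “predictions”; `tryCatch` captures the error message instead of stopping the script:

```r
passengers  <- data.frame(PassengerId = 1:3)   # toy data frame with 3 rows
predictions <- c(0, 1)                         # only 2 predictions came back

msg <- tryCatch(
  {
    passengers$Survived <- predictions   # lengths differ, so this fails
    "assignment succeeded"
  },
  error = function(e) conditionMessage(e)
)
msg
```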

One annoying thing about caret and randomForest is that if there is missing data in the variables you are using to predict, it will just not return a prediction at all (and it won't throw an error!). So we have to find the missing data ourselves.
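One quick way to hunt for missing values is to count the NA's in each column. The sketch below uses a small made-up data frame (`toy_test`, with hypothetical values) so that it runs on its own:

```r
# Toy stand-in for the test set, with one missing Fare
toy_test <- data.frame(
  Pclass = c(3, 1, 2),
  Fare   = c(7.25, NA, 13.00)
)

# Number of NA values in each column
na_counts <- colSums(is.na(toy_test))
na_counts
```

On the real data you would run `colSums(is.na(testSet))`, or call `summary(testSet)`, which also reports the NA count for each numeric column.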

As you can see, the variable “Fare” has one NA value. Let's fill in (“impute”) that value with the mean of the “Fare” column (there are better and fancier ways to do this, but that is for another post). We do this with an ifelse statement. Read it as follows: if an entry in the column “Fare” is NA, then replace it with the mean of the column (also removing the NA's when you take the mean). Otherwise, leave it the same.
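Here is the ifelse pattern on a toy vector (made-up fares), so you can see exactly what it does before applying it to the real column:

```r
fare <- c(7.25, NA, 13.00)   # toy values with one missing entry

# If an entry is NA, use the mean of the non-missing entries; otherwise keep it
fare_filled <- ifelse(is.na(fare), mean(fare, na.rm = TRUE), fare)
fare_filled
```

Applied to the test set, the same statement would read `testSet$Fare <- ifelse(is.na(testSet$Fare), mean(testSet$Fare, na.rm = TRUE), testSet$Fare)`.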

Okay, now that we fixed that missing value, we can try again to run the predict method.

Let's remove the unnecessary columns that Kaggle doesn't want, and then write the testSet to a csv file.
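A sketch of the export step, assuming the submission needs just the PassengerId and Survived columns. `write_submission` is a hypothetical helper name, and the demonstration uses a made-up three-row data frame and a temporary file so it runs without the model:

```r
# Keep only the two submission columns and write them as a CSV file
write_submission <- function(test_data, path) {
  submission <- test_data[, c("PassengerId", "Survived")]
  write.table(submission, file = path, sep = ",",
              col.names = TRUE, row.names = FALSE, quote = FALSE)
  invisible(submission)
}

# Toy demonstration with made-up predictions
toy <- data.frame(
  PassengerId = 892:894,
  Pclass      = c(3, 3, 2),
  Survived    = factor(c(0, 1, 0))
)
out <- tempfile(fileext = ".csv")
write_submission(toy, out)
written <- readLines(out)
written
```

In the tutorial's own session, the equivalent would be `testSet$Survived <- predict(model, newdata = testSet)` followed by `write_submission(testSet, "submission.csv")`.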

Uploading your predictions to Kaggle

Uploading predictions is easy. Just go to the Kaggle page for the competition, click “Make a submission” on the sidebar, and select the file submission.csv. Click “Submit,” and then Kaggle will score your results on the test set.


Well, we didn't win, but we did pretty well. In fact, we beat several hundred other people and one of the benchmarks created by Kaggle! Our accuracy on the test set was 77%, which is pretty close to the cross-validation results of 82%. Not bad for our first model ever.

You can find the R code here.

Improving the model

This post only scratches the surface of what you can do with R and caret. Here are a few ideas for things to try in order to improve the model.

  • Try including different variables in the model: leave some out or add some in
  • Try combining variables into more useful variables: sometimes you can multiply or add variables together, or concatenate different categorical variables together
  • Try transforming the existing variables in clever ways: maybe turn a numerical variable into a categorical variable based on different ranges (e.g. 0-10, 10-90, 90-100)
  • Try a different algorithm: maybe neural networks, logistic regression or gradient boosting machines work better. Better yet, train a few different types of models and combine the results by averaging the probabilities (this is called ensembling)

Next steps

Okay, so you've done one machine learning classification tutorial and submitted a solution to Kaggle. That's an awesome start, and it's more than the vast majority of people ever do. So what's next? Here are a few things you can do:

  • Try another Kaggle competition! There are a few competitions out there that are great for learning, like Give Me Some Credit or Don’t Get Kicked. The forums contain lots of great advice and example solutions.
  • Learn more about predictive analytics and caret. The book Applied Predictive Modeling was written by Max Kuhn, the creator of caret. I haven't read it, but it comes highly recommended. His blog has also been incredibly useful to me.
  • Keep reading this blog! I will continue to post about practical machine learning. If you'd like, you can subscribe to my email list on the sidebar so that you never miss a post.
