# Model assessment
```{r setup-model, include=FALSE}
library(purrr)
set.seed(1014)
options(digits = 3)
```
In this chapter, you'll turn the tools of multiple models towards model assessment: learning how the model performs when given new data. So far we've focussed on models as tools for description, using models to help us understand the patterns in the data we have collected so far. But ideally a model will do more than just describe what we have seen so far - it will also help predict what will come next.
In other words, we want a model that doesn't just perform well on the sample, but also accurately summarises the underlying population.
In some industries this is the primary use of models: you spend relatively little time fitting the model compared to how many times you use it (models as pets vs. models as livestock).
There are two basic ways that a model can fail with new data:
* You can under- or over-fit the model. Underfitting is where you fail
  to model an important trend: you leave too much in the residuals, and not
  enough in the model. Overfitting is the opposite: you fit a trend to
  what is actually random noise; you've put too much in the model and not
  left enough in the residuals. Generally overfitting tends to be more of a
  problem than underfitting.
* The process that generates the data might change. There's nothing the
  model can do about this. You can protect yourself against it to some
  extent by creating models that you understand and applying your knowledge
  to the problem: are these fundamentals likely to change? If you have
  a model that you are going to use again and again for a long time, you
  need to plan to maintain it, regularly checking that it still makes
  sense, i.e. that the population is still the same.
<http://research.google.com/pubs/pub43146.html>
<http://www.wired.com/2015/10/can-learn-epic-failure-google-flu-trends/>
The most common reason a model does poorly with new data is overfitting.
Obviously, there's a bit of a problem here: we don't have new data with which to check the model, and even if we did, we'd presumably use it to make the model better in the first place. One powerful family of approaches can help us get around this problem: resampling.
There are two main resampling techniques that we're going to cover.
* We will use __cross-validation__ to assess model quality. In
  cross-validation, you split the data into test and training sets. You fit
  the model to the training set, and evaluate it on the test set. This avoids
  the intrinsic bias of using the same data to both fit the model and assess
  its quality. However, it introduces a new bias: you're not using all the
  data to fit the model, so it's not going to be quite as good as it could be.

* We will use __bootstrapping__ to understand how stable (or how variable)
  the model is. If you sampled data from the same population multiple times,
  how much would your model vary? Instead of going back to collect new data,
  you can use the best estimate of the population that you have: the data
  you've collected so far. The amazing idea of the bootstrap is that you can
  resample from the data you already have. (A minimal hand-rolled sketch of
  both ideas follows this list.)
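To make these two ideas concrete before reaching for any helpers, here's a minimal hand-rolled sketch. Everything in it (the `sim` data frame, `mod_sim`, and the 70/30 split) is made up purely for illustration; the modelr helpers used later in the chapter automate and repeat both steps, and handle the bookkeeping for you.

```{r}
# A hand-rolled sketch of both resampling ideas (illustration only)
library(modelr)

sim <- data.frame(x = runif(100))
sim$y <- 2 * sim$x + rnorm(100, sd = 0.25)

# Cross-validation idea: hold out 30% of the rows, fit on the other 70%
in_train <- sample(nrow(sim), size = 0.7 * nrow(sim))
mod_sim <- lm(y ~ x, data = sim[in_train, ])
rmse(mod_sim, sim[in_train, ])   # error on the training set
rmse(mod_sim, sim[-in_train, ])  # error on the held-out test set

# Bootstrap idea: resample the rows with replacement, then refit
boot_rows <- sample(nrow(sim), replace = TRUE)
coef(lm(y ~ x, data = sim[boot_rows, ]))
```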
There are lots of high-level helpers for these resampling methods in R. We're going to use the tools provided by the modelr package because they are explicit - you'll see exactly what's going on at each step.
If you want a more comprehensive toolkit, check out the caret package, <http://topepo.github.io/caret>, and [Applied Predictive Modeling](https://amzn.com/1461468485) by Max Kuhn and Kjell Johnson.
If you're competing in competitions, like Kaggle, that are predominantly about creating good predictions, developing a good strategy for avoiding overfitting is very important. Otherwise you risk tricking yourself into thinking that you have a good model, when in reality you just have a model that does a good job of fitting your data.
There is a closely related family of techniques that uses a similar idea: model ensembles. However, instead of trying to find the single best model, ensembles make use of all the models, acknowledging that even models that don't fit all the data particularly well can still model some subsets well. In general, you can think of model ensemble techniques as functions that take a list of models, and return a single model that attempts to take the best part of each.
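To make that definition concrete, here's a hypothetical sketch (not a function from modelr or any other package) of the simplest possible ensemble: average the predictions from a list of fitted models.

```{r}
# Hypothetical sketch: ensemble a list of models by averaging predictions
ensemble_predict <- function(models, newdata) {
  preds <- purrr::map(models, predict, newdata = newdata)
  purrr::reduce(preds, `+`) / length(preds)
}
```

Real ensembling techniques are usually smarter than a plain average (e.g. weighting models by how well they predict), but the shape is the same: many models in, one set of predictions out.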
### Prerequisites
```{r setup, message = FALSE}
# Standard data manipulation and visualisation
library(dplyr)
library(ggplot2)
# Tools for working with models
library(broom)
library(modelr)
library(splines)
# Tools for working with lots of models
library(purrr)
library(tidyr)
```
## Overfitting
Both bootstrapping and cross-validation help us to spot and remedy the problem of __overfitting__, where the model fits the data we've seen so far extremely well, but does a bad job of generalising to new data.
A classic example of overfitting is using a spline with too many degrees of freedom.

This is the bias-variance tradeoff: a simpler model is more biased, a more complex model is more variable. Occam's razor suggests preferring the simpler model when both describe the data about as well.

Let's simulate some data from a known model:
```{r}
true_model <- function(x) {
1 + 2 * x + rnorm(length(x), sd = 0.25)
}
df <- data_frame(
x = seq(0, 1, length = 20),
y = true_model(x)
)
df %>%
ggplot(aes(x, y)) +
geom_point()
```
We can create a model that fits this data very well:
```{r, message = FALSE}
library(splines)
my_model <- function(df) {
lm(y ~ ns(x, 5), data = df)
}
mod <- my_model(df)
rmse(mod, df)
grid <- df %>% expand(x = seq_range(x, 50))
preds <- grid %>% add_predictions(mod, var = "y")
df %>%
ggplot(aes(x, y)) +
geom_line(data = preds) +
geom_point()
```
But do you think this model will do well if we apply it to new data from the same population?
This case is a simulation, so we could just resimulate data from the same process and see how well it does:
```{r}
df2 <- df %>% mutate(y = true_model(x))
rmse(mod, df2)
```
Obviously it does much worse. But in real life you can't easily go out and recollect your data. There are two approaches to help you get around this problem. I'll introduce them briefly here, and then we'll go into more depth in the following sections.
## Bootstrapping

Here's the bootstrap applied to our spline model: we fit the same model to 100 bootstrap resamples of `df`, and then plot the predictions from every refit so we can see how much they vary.
```{r}
boot <- bootstrap(df, 100) %>% mutate(
mod = map(strap, my_model),
pred = map2(list(grid), mod, add_predictions)
)
boot %>%
unnest(pred, .id = "id") %>%
ggplot(aes(x, pred, group = id)) +
geom_line(alpha = 1/3)
```
(You might notice that while each individual model varies a lot, the average of all the models seems like it's pretty good. That gives rise to a model ensemble technique called model averaging.)
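To see that, here's a quick sketch that averages the 100 bootstrap predictions at each value of `x` (using the `boot` object created above) and overlays the resulting ensemble line on the original data:

```{r}
# Average the bootstrap predictions at each x and compare to the data
boot %>%
  unnest(pred, .id = "id") %>%
  group_by(x) %>%
  summarise(pred = mean(pred)) %>%
  ggplot(aes(x, pred)) +
  geom_line(colour = "red") +
  geom_point(aes(y = y), data = df)
```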
## Cross-validation

We could instead use cross-validation to focus on a summary of model quality: repeatedly split the data into training and test sets, fit the model to each training set, and compute the rmse on the corresponding test set. It basically works like this:
```{r}
cv <- crossv_mc(df, 100, test = 0.3) %>%
mutate(
mod = map(train, my_model),
rmse = map2_dbl(mod, test, rmse)
)
cv %>%
ggplot(aes(rmse)) +
geom_ref_line(v = rmse(mod, df)) +
geom_freqpoly(binwidth = 0.05) +
geom_rug()
mean(cv$rmse)
```