More on model vis

This commit is contained in:
hadley 2016-05-10 20:21:23 -05:00
parent 46d80495fd
commit f62008d075
1 changed file with 122 additions and 44 deletions


@ -3,6 +3,8 @@ library(broom)
library(ggplot2)
library(dplyr)
library(lubridate)
library(tidyr)
library(nycflights13)
```
# Model visualisation
@ -15,6 +17,8 @@ In this chapter we will explore model visualisation from two different sides:
We're going to give you a basic strategy, and point you to places to learn more. The key is to think about data generated from your model as regular data - you're going to want to manipulate it and visualise it in many different ways.
The chapter is built around two techniques: looking at residuals and looking at predictions. You'll see them applied here to linear models (and some minor variations), but they're flexible techniques, since every model can generate predictions and residuals.
Being good at modelling is a mixture of having some good general principles and having a big toolbox of techniques. Here we'll focus on general techniques to help you understand what your model is telling you.
Focus on constructing models that help you better understand the data. This will generally lead to models that predict better. But you have to beware of overfitting the data - in the next section we'll discuss some formal methods. A healthy dose of scepticism is also powerful: do you believe that a pattern you see in your sample is going to generalise to a wider population?
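The reason this strategy is so general is that every fitted model object in R supports these two operations. As a minimal sketch (using a toy model on `mtcars`, not one of this chapter's models):

```{r}
# Any fitted model can generate predictions and residuals:
mod_toy <- lm(mpg ~ wt, data = mtcars)
head(predict(mod_toy)) # what the model expects for each observation
head(resid(mod_toy))   # how far each observation is from that expectation
```

Once you have those two vectors, everything else is ordinary data manipulation and visualisation.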
@ -102,7 +106,7 @@ Note the change in the y-axis: now we are seeing the deviation from the expected
If you're familiar with American public holidays, you might spot New Year's
day, July 4th, Thanksgiving and Christmas. There are some others that don't
seem to correspond immediately to public holidays. You'll work on those
in the exercise below.
1. There seems to be some smoother long term trend over the course of a year.
@ -131,7 +135,7 @@ daily %>%
scale_x_datetime(NULL, date_breaks = "1 month", date_labels = "%b")
```
So it looks like summer holidays are from early June to late August. That seems to line up fairly well with the [state's school terms](http://schools.nyc.gov/Calendar/2013-2014+School+Year+Calendars.htm): summer break is Jun 26 - Sep 9. So let's add a "term" variable to attempt to control for that.
```{r}
daily <- daily %>%
@ -148,6 +152,8 @@ daily %>%
scale_x_datetime(NULL, date_breaks = "1 month", date_labels = "%b")
```
(I manually tweaked the dates to get nice breaks in the plot.)
It's useful to see how this new variable affects the other days of the week:
```{r}
@ -179,17 +185,18 @@ middles <- daily %>%
)
middles %>%
  ggplot(aes(wday)) + 
    geom_linerange(aes(ymin = mean, ymax = median), colour = "grey70") +
    geom_point(aes(y = mean, colour = "mean")) + 
    geom_point(aes(y = median, colour = "median")) + 
facet_wrap(~ term)
```
We can reduce this problem by switching to a robust model fitted by `MASS::rlm()`. A robust model is a variation of the linear model which you can think of as fitting medians instead of means (it's a bit more complicated than that, but that's a reasonable intuition). This greatly reduces the impact of the outliers on our estimates, and gives a result that does a good job of removing the day of week pattern:
```{r, warning = FALSE}
mod2 <- MASS::rlm(n ~ wday * term, data = daily)
daily <- daily %>% add_residuals(n_resid2 = mod2)
ggplot(daily, aes(date, n_resid2)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
@ -212,10 +219,10 @@ It's now much easier to see the long-term trend, and the positive and negative o
daily %>% filter(n_resid2 > 80)
```
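As an aside, you can see the "fitting medians" intuition directly on simulated data (not part of the flights data): a handful of large outliers drags the `lm()` slope away from the truth, while `MASS::rlm()` largely ignores them.

```{r}
# Illustrative only: straight line plus noise, with three big outliers at one end
set.seed(1)
sim <- data.frame(x = 1:100)
sim$y <- 2 * sim$x + rnorm(100)
sim$y[c(90, 95, 100)] <- sim$y[c(90, 95, 100)] + 500

coef(lm(y ~ x, data = sim))        # slope noticeably pulled up by the outliers
coef(MASS::rlm(y ~ x, data = sim)) # slope stays close to the true value, 2
```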
1. Create a new variable that splits the `wday` variable into terms, but only
   for Saturdays, i.e. it should have `Thurs`, `Fri`, but `Sat-summer`, 
   `Sat-spring`, `Sat-fall`. How does this model compare with the model with 
   every combination of `wday` and `term`?
1. Create a new wday variable that combines the day of week, term
(for Saturdays), and public holidays. What do the residuals of
@ -238,49 +245,109 @@ It's now much easier to see the long-term trend, and the positive and negative o
We'll focus on visualising predictions from a model, because this works regardless of the model family. Visualising parameters can also be useful, but tends to pay off most when you have many similar models.
```{r}
```
Visualising high-dimensional models is challenging. You'll need to partition off a usable slice at a time.
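One practical way to partition off a slice is to let one variable vary while holding the others fixed at a representative level, then predict on just that slice. A sketch with the models from this chapter (using `modelr::add_predictions()`, which we'll meet properly below):

```{r}
# One slice of mod2: the day-of-week effect, holding term fixed at "spring"
daily %>% 
  tidyr::expand(wday) %>% 
  mutate(term = factor("spring", levels = levels(daily$term))) %>% 
  add_predictions(pred = mod2) %>% 
  ggplot(aes(wday, pred)) + 
  geom_point()
```

With more predictors you'd repeat this for each interesting slice, or facet over a second variable.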
### `rlm()` vs `lm()`
Let's start by exploring the difference between the `lm()` and `rlm()` predictions for the day of week effects. We'll first re-fit the models, just so we have them handy:
```{r}
mod1 <- lm(n ~ wday * term, data = daily)
mod2 <- MASS::rlm(n ~ wday * term, data = daily)
```
Next, we need to generate a grid of values to compute predictions for. The easiest way to do that is to use `tidyr::expand()`. Its first argument is a data frame, and for each subsequent argument it finds the unique values and then generates all combinations:

```{r}
grid <- 
  daily %>% 
  tidyr::expand(wday, term)
grid
```

Next we add predictions. We'll use `modelr::add_predictions()`, which works in exactly the same way as `add_residuals()`, but just computes predictions (so it doesn't need a data frame that contains the response variable):

```{r}
grid <- 
  grid %>% 
  add_predictions(linear = mod1, robust = mod2)
grid
```

And then we plot the predictions. Plotting predictions is usually the hardest bit, and you'll need to try a few times before you get a plot that is most informative. Depending on your model it's quite possible that you'll need multiple plots to fully convey what the model is telling you about the data.

```{r}
grid %>% 
  ggplot(aes(wday)) + 
    geom_linerange(aes(ymin = linear, ymax = robust), colour = "grey70") +
    geom_point(aes(y = linear, colour = "linear")) + 
    geom_point(aes(y = robust, colour = "robust")) + 
    facet_wrap(~ term)
```

### Computed variables

```{r}
date_vars <- function(df) {
  df %>% mutate(
    term = cut(date, 
      breaks = as.POSIXct(ymd(20130101, 20130605, 20130825, 20140101)),
      labels = c("spring", "summer", "fall")
    ),
    wday = wday(date, label = TRUE)
  )
}

daily %>% 
  expand(date) %>% 
  date_vars() %>% 
add_predictions(pred = mod2) %>%
ggplot(aes(date, pred)) +
geom_line()
daily %>%
expand(date, wday = "Sat", term = "spring") %>%
add_predictions(pred = mod2) %>%
ggplot(aes(date, pred)) +
geom_line()
daily %>%
expand(wday, term) %>%
add_predictions(pred = mod2) %>%
ggplot(aes(wday, pred, colour = term)) +
geom_point() +
geom_line(aes(group = term))
```
If you're experimenting with many models and many visualisations, it's a good idea to bundle the creation of variables up into a function so there's no chance of accidentally applying a different transformation in different places.
### Nested variables
Another case that occasionally crops up is nested variables: you have an identifier that is locally unique, not globally unique. For example you might have this data about students in schools:
```{r}
students <- tibble::frame_data(
~student_id, ~school_id,
1, 1,
2, 1,
1, 2,
1, 3,
2, 3,
3, 3
)
```
The student id only makes sense in the context of the school: it doesn't make sense to generate every combination of student and school. You can use `nesting()` for this case:
```{r}
students %>% expand(nesting(school_id, student_id))
```
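To see what `nesting()` buys you, compare it with the default crossing behaviour, which generates combinations that don't exist in the data:

```{r}
# All 9 school x student combinations, including students who don't exist:
students %>% expand(school_id, student_id)

# Only the 6 combinations actually present in the data:
students %>% expand(nesting(school_id, student_id))
```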
### Continuous variables
```{r}
grid <- nlme::Oxboys %>%
as_data_frame() %>%
tidyr::expand(Subject, age = seq_range(age, 2))
mod <- nlme::lme(height ~ age, random = ~1 | Subject, data = nlme::Oxboys)
grid %>%
add_predictions(mod = mod) %>%
ggplot(aes(age, mod)) +
geom_line(aes(group = Subject))
```
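A note on the grid above: `modelr::seq_range()` generates evenly spaced values across the range of a variable. Two points are enough here because each fitted line is straight; for a non-linear model you'd ask for more:

```{r}
# Evenly spaced values across a range; the second argument is the number of points
seq_range(c(10, 25), 5)
#> 10.00 13.75 17.50 21.25 25.00
```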
### Exercises
1. How does the plot of model coefficients compare to the plot of means
   and medians computed "by hand" in the previous chapter? Create a plot
   that highlights the differences and similarities.
## Delays and weather
```{r}
@ -291,9 +358,20 @@ hourly <- flights %>%
) %>%
inner_join(weather, by = c("origin", "time_hour"))
# ggplot(hourly, aes(time_hour, delay)) +
# geom_point()
#
# ggplot(hourly, aes(hour(time_hour), sign(delay) * sqrt(abs(delay)))) +
# geom_boxplot(aes(group = hour(time_hour)))
#
# hourly %>%
# filter(wind_speed < 999) %>%
# ggplot(aes(temp, delay)) +
# geom_point() +
# geom_smooth()
ggplot(hourly, aes(hour(time_hour), delay)) +
geom_boxplot(aes(group = hour(time_hour)))
```
## Learning more
<https://cran.rstudio.com/web/packages/condvis/>