Update to latest modelr

This commit is contained in:
hadley 2016-06-15 14:10:48 -05:00
parent 9267d69365
commit 820405f66d
4 changed files with 104 additions and 81 deletions

View File

@ -76,7 +76,7 @@ mod <- my_model(df)
rmse(mod, df)
grid <- df %>% expand(x = seq_range(x, 50))
preds <- grid %>% add_predictions(y = mod)
preds <- grid %>% add_predictions(mod, var = "y")
df %>%
ggplot(aes(x, y)) +
@ -98,25 +98,25 @@ Obviously it does much worse. But in real-life you can't easily go out and reco
```{r}
boots <- rerun(100, df %>% mutate(y = true_model(x)))
mods <- map(boots, my_model)
preds <- map2_df(list(grid), mods, ~ add_predictions(.x, y = .y), .id = "id")
preds <- map2_df(list(grid), mods, add_predictions, .id = "id")
preds %>%
ggplot(aes(x, y, group = id)) +
ggplot(aes(x, pred, group = id)) +
geom_line(alpha = 1/3)
```
```{r}
boot <- bootstrap(df, 100)
boot <- modelr::bootstrap(df, 100)
mods <- boot$strap %>% map(safely(my_model)) %>% transpose()
ok <- mods$error %>% map_lgl(is.null)
```
```{r}
preds <- map2_df(list(grid), mods$result[ok], ~ add_predictions(.x, y = .y), .id = "id")
preds <- map2_df(list(grid), mods$result[ok], add_predictions, .id = "id")
preds %>%
ggplot(aes(x, y, group = id)) +
ggplot(aes(x, pred, group = id)) +
geom_line(alpha = 1/3)
```

View File

@ -188,9 +188,7 @@ grid
Next we add predicitons. We'll use `modelr::add_predictions()` which works in exactly the same way as `add_residuals()`, but just compute predictions (so doesn't need a data frame that contains the response variable:)
```{r}
grid <-
grid %>%
add_predictions(income = h)
grid <- grid %>% add_predictions(h, "income")
grid
```
@ -207,15 +205,13 @@ ggplot(heights, aes(height, income)) +
The flip-side of predictions are residuals. The predictions tell you what the model is doing; the residuals tell you what the model is missing. We can compute residuals with `add_residuals()`. Note that we computing residuals, you'll use the original dataset, not a manufactured grid. Otherwise where would you get the value of the response?
```{r}
heights <-
heights %>%
add_residuals(resid_h = h)
heights <- heights %>% add_residuals(h)
```
There are a few different ways to understand what the residuals tell us about the model. One way is to simply draw a frequency polygon to help us understand the spread of the residuals:
```{r}
ggplot(heights, aes(resid_h)) +
ggplot(heights, aes(resid)) +
geom_freqpoly(binwidth = 2000)
```
@ -226,14 +222,14 @@ Here you can see that the range of the residuals is quite large. (Note that by
For many problems, the sign of the residual (i.e. whether the prediction is too high or too low) isn't important, and you might just want to focus on the magnitude of the residuals. You can do that by plotting the absolute value:
```{r}
ggplot(heights, aes(abs(resid_h))) +
ggplot(heights, aes(abs(resid))) +
geom_freqpoly(binwidth = 2000, boundary = 0)
```
You can also explore how the residuals vary with other variables in the data:
```{r}
ggplot(heights, aes(height, income)) + geom_point()
ggplot(heights, aes(height, resid)) + geom_point()
```
Iterative plotting the residuals instead of the original response leads to a natual way of building up a complex model in simple steps, which we'll explore in detail in the next chapter.
@ -245,14 +241,14 @@ When you start dealing with many models, it's helpful to have some rough way of
One way to capture the quality of the model is to summarise the distribution of the residuals. For example, you could look at the quantiles of the absolute residuals. For this dataset, 25% of predictions are less than \$7,400 away, and 75% are less than \$25,800 away. That seems like quite a bit of error when predicting someone's income!
```{r}
quantile(abs(heights$resid_h), c(0.25, 0.75))
quantile(abs(heights$resid), c(0.25, 0.75))
range(heights$income)
```
You might be familiar with the $R^2$. That's a single number summary that rescales the variance of the residuals to between 0 (very bad) and 1 (very good):
```{r}
(var(heights$income) - var(heights$resid_h)) / var(heights$income)
(var(heights$income) - var(heights$resid)) / var(heights$income)
```
This is why the $R^2$ is sometimes interpreted as the amount of variation in the data explained by the model. Here we're explaining 3% of the total variation - not a lot! But I don't think worrying about the relative amount of variation explained is that useful; instead I think you need to consider whether the absolute amount of variation explained is useful for your project.
@ -278,7 +274,7 @@ The $R^2$ is an ok single number summary, but I prefer to think about the unscal
science, that we're not going to talk much about. <https://xkcd.com/882/>
1. It's often useful to recreate your initial EDA plots using residuals
instead of the original missing values. How does visualising `resid_h`
instead of the original missing values. How does visualising `resid`
instead of `height` change your understanding of the heights data?
## Multiple predictors
@ -311,7 +307,7 @@ What happens if we also include `sex` in the model?
h2 <- lm(income ~ height * sex, data = heights)
grid <- heights %>%
expand(height, sex) %>%
add_predictions(income = h2)
add_predictions(h2, "income")
ggplot(heights, aes(height, income)) +
geom_point() +
@ -327,10 +323,9 @@ Need to commment about predictions for tall women and short men - there is not a
h3 <- lm(income ~ height + sex, data = heights)
grid <- heights %>%
expand(height, sex) %>%
add_predictions(h2 = h2, h3 = h3) %>%
gather(model, prediction, h2:h3)
gather_predictions(h2, h3)
ggplot(grid, aes(height, prediction, colour = sex)) +
ggplot(grid, aes(height, pred, colour = sex)) +
geom_line() +
facet_wrap(~model)
```
@ -449,10 +444,9 @@ How can we visualise the results of this model? One way to think about it as a s
```{r}
grid <- heights_ed %>%
expand(height, education) %>%
add_predictions(he1 = he1, he2 = he2) %>%
gather(model, prediction, he1:he2)
gather_predictions(he1, he2)
ggplot(grid, aes(height, education, fill = prediction)) +
ggplot(grid, aes(height, education, fill = pred)) +
geom_raster() +
facet_wrap(~model)
```
@ -460,10 +454,10 @@ ggplot(grid, aes(height, education, fill = prediction)) +
It's easier to see what's going on in a line plot:
```{r}
ggplot(grid, aes(height, prediction, group = education)) +
ggplot(grid, aes(height, pred, group = education)) +
geom_line() +
facet_wrap(~model)
ggplot(grid, aes(education, prediction, group = height)) +
ggplot(grid, aes(education, pred, group = height)) +
geom_line() +
facet_wrap(~model)
```
@ -476,7 +470,7 @@ heights_ed %>%
height = seq_range(height, 10),
education = mean(education, na.rm = TRUE)
) %>%
add_predictions(income = he1) %>%
add_predictions(he1, "income") %>%
ggplot(aes(height, income)) +
geom_line()
@ -485,7 +479,7 @@ heights_ed %>%
height = mean(height, na.rm = TRUE),
education = seq_range(education, 10)
) %>%
add_predictions(income = he1) %>%
add_predictions(he1, "income") %>%
ggplot(aes(education, income)) +
geom_line()
```
@ -516,8 +510,7 @@ mod_e2 <- lm(income ~ education + I(education ^ 2) + I(education ^ 3), data = he
heights_ed %>%
expand(education) %>%
add_predictions(mod_e1 = mod_e1, mod_e2 = mod_e2) %>%
gather(model, pred, mod_e1:mod_e2) %>%
gather_predictions(mod_e1, mod_e2) %>%
ggplot(aes(education, pred, colour = model)) +
geom_point() +
geom_line()
@ -532,8 +525,7 @@ mod_e3 <- lm(income ~ poly(education, 3), data = heights_ed)
heights_ed %>%
expand(education) %>%
add_predictions(mod_e1 = mod_e1, mod_e2 = mod_e2, mod_e3 = mod_e3) %>%
gather(model, pred, mod_e1:mod_e3) %>%
gather_predictions(mod_e1, mod_e2, mod_e3) %>%
ggplot(aes(education, pred, colour = model)) +
geom_point() +
geom_line()
@ -543,8 +535,7 @@ However: there's one major problem with using `poly()`: outside the range of the
```{r}
data_frame(education = seq(5, 25)) %>%
add_predictions(mod_e1 = mod_e1, mod_e2 = mod_e2, mod_e3 = mod_e3) %>%
gather(model, pred, mod_e1:mod_e3) %>%
gather_predictions(mod_e1, mod_e2, mod_e3) %>%
ggplot(aes(education, pred, colour = model)) +
geom_line()
```
@ -558,8 +549,7 @@ mod_e2 <- lm(income ~ ns(education, 2), data = heights_ed)
mod_e3 <- lm(income ~ ns(education, 3), data = heights_ed)
data_frame(education = seq(5, 25)) %>%
add_predictions(mod_e1 = mod_e1, mod_e2 = mod_e2, mod_e3 = mod_e3) %>%
gather(model, pred, mod_e1:mod_e3) %>%
gather_predictions(mod_e1, mod_e2, mod_e3) %>%
ggplot(aes(education, pred, colour = model)) +
geom_line()
```

View File

@ -80,17 +80,17 @@ One way to remove this strong pattern is to fit a model that "explains" (i.e. at
```{r}
mod <- lm(n ~ wday, data = daily)
daily <- daily %>% add_residuals(n_resid = mod)
daily <- daily %>% add_residuals(mod)
daily %>%
ggplot(aes(date, n_resid)) +
ggplot(aes(date, resid)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line()
daily %>%
expand(wday) %>%
add_predictions(n_pred = mod) %>%
ggplot(aes(wday, n_pred)) +
add_predictions(mod) %>%
ggplot(aes(wday, pred)) +
geom_point()
```
@ -102,7 +102,7 @@ Note the change in the y-axis: now we are seeing the deviation from the expected
to see:
```{r}
ggplot(daily, aes(date, n_resid, colour = wday)) +
ggplot(daily, aes(date, resid, colour = wday)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line()
```
@ -117,7 +117,7 @@ Note the change in the y-axis: now we are seeing the deviation from the expected
1. There are some days with far fewer flights than expected:
```{r}
daily %>% filter(n_resid < -100)
daily %>% filter(resid < -100)
```
If you're familiar with American public holidays, you might spot New Year's
@ -130,7 +130,7 @@ Note the change in the y-axis: now we are seeing the deviation from the expected
```{r}
daily %>%
ggplot(aes(date, n_resid)) +
ggplot(aes(date, resid)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line(colour = "grey50") +
geom_smooth(se = FALSE, span = 0.20)
@ -186,12 +186,12 @@ daily %>%
It looks like there is significant variation across the terms, so fitting a separate day of week effect for each term is reasonable. This improves our model, but not as much as we might hope:
```{r}
mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)
daily$n_resid2 <- resid(mod2)
daily %>%
gather(model, prediction, n_resid, n_resid2) %>%
ggplot(aes(date, prediction, colour = model)) +
gather_residuals(mod1, mod2) %>%
ggplot(aes(date, resid, colour = model)) +
geom_line(alpha = 0.75)
```
@ -199,9 +199,10 @@ That's because this model is basically calculating an average for each combinati
```{r, warn = FALSE}
mod3 <- MASS::rlm(n ~ wday * term, data = daily)
daily <- daily %>% add_residuals(n_resid3 = mod3)
ggplot(daily, aes(date, n_resid3)) +
daily %>%
add_residuals(mod3, "resid") %>%
ggplot(aes(date, resid)) +
geom_hline(yintercept = 0, size = 2, colour = "white") +
geom_line()
```
@ -221,9 +222,9 @@ mod <- MASS::rlm(n ~ wday * yday(date), data = daily)
grid <- daily %>%
tidyr::expand(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod = mod)
add_predictions(mod)
ggplot(grid, aes(date, mod, colour = wday)) +
ggplot(grid, aes(date, pred, colour = wday)) +
geom_line() +
geom_point()
```
@ -238,8 +239,8 @@ mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>%
tidyr::expand(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod = mod) %>%
ggplot(aes(date, mod, colour = wday)) +
add_predictions(mod) %>%
ggplot(aes(date, pred, colour = wday)) +
geom_line() +
geom_point()
```
@ -272,7 +273,7 @@ wday2 <- function(x) wday(x, label = TRUE)
mod3 <- lm(n ~ wday2(date) * term(date), data = daily)
daily %>%
expand(date) %>%
add_predictions(pred = mod3)
add_predictions(mod3)
```
I think this is fine to do provided that you've carefully checked that the functions do what you think they do (i.e. with a visualisation). There are two disadvantages:

View File

@ -68,13 +68,13 @@ nz %>%
nz_mod <- lm(lifeExp ~ year, data = nz)
nz %>%
add_predictions(pred = nz_mod) %>%
add_predictions(nz_mod) %>%
ggplot(aes(year, pred)) +
geom_line() +
ggtitle("Linear trend + ")
nz %>%
add_residuals(resid = nz_mod) %>%
add_residuals(nz_mod) %>%
ggplot(aes(year, resid)) +
geom_hline(yintercept = 0, colour = "white", size = 3) +
geom_line() +
@ -144,7 +144,7 @@ Previously we computed the residuals of a single model with a single dataset. No
```{r}
by_country <- by_country %>% mutate(
resids = map2(data, model, ~ add_residuals(.x, resid = .y))
resids = map2(data, model, add_residuals)
)
by_country
```
@ -256,45 +256,77 @@ We see two main effects here: the tragedies of the HIV/AIDS epidemic, and the Rw
## List-columns
The idea of a list column is powerful. The contract of a data frame is that it's a named list of vectors, where each vector has the same length. A list is a vector, and a list can contain anything, so you can put anything in a list-column of a data frame.
Now that you've seen a basic workflow for managing many models, lets dive back into some of the details. In this section, we'll dive into the notional of the list-column in a little more detail, and then we'll give a few more details about `nest()`/`unnest()`.
Generally, you should make sure that your list columns are homogeneous: each element should contain the same type of thing. There are no checks to make sure this is true, but if you use purrr and remember what you've learned about type-stable functions you should find it happens naturally.
It's only in the last year that I've really appreciated the idea of the list-column. List-columns are implicit in the defintion of the data frame: a data frame is a named list of equal length vectors. A list is a vector, so it's always been legitimate to put use a list as a column of a data frame.
### Compared to base R
List columns are possible in base R, but conventions in `data.frame()` make creating and printing them a bit of a headache:
```{r, error = TRUE}
# Doesn't work
data.frame(x = list(1:2, 3:5))
# Works, but doesn't print particularly well
data.frame(x = I(list(1:2, 3:5)), y = c("1, 2", "3, 4, 5"))
```
The functions in tibble don't have this problem:
However, base R doesn't make it easier to create list-columns, and `data.frame()` treats a list as a list of columns:.
```{r}
data_frame(x = list(1:2, 3:5), y = c("1, 2", "3, 4, 5"))
data.frame(x = list(1:3, 3:5))
```
### With `mutate()` and `summarise()`
You might find yourself creating list-columns with mutate and summarise. For example:
You can prevent `data.frame()` from doing this with `I()`, but the result doesn't print particularly well:
```{r}
data_frame(x = c("a,b,c", "d,e,f,g")) %>%
mutate(x = stringr::str_split(x, ","))
data.frame(
x = I(list(1:3, 3:5)),
y = c("1, 2", "3, 4, 5")
)
```
`unnest()` knows how to handle these lists of vectors as well as lists of data frames.
Tibble alleviates this problem by not messing with the inputs to `data_frame()`, and by providing a better print method:
```{r}
data_frame(x = c("a,b,c", "d,e,f,g")) %>%
mutate(x = stringr::str_split(x, ",")) %>%
data_frame(
x = list(1:3, 3:5),
y = c("1, 2", "3, 4, 5")
)
```
Generally, when creating list-columns, you should make sure they're homogeneous: each element should contain the same type of thing. There are no checks to make sure this is true, but if you use purrr and remember what you've learned about type-stable functions you should find it happens naturally.
You've seen two importany way of generating list-columns in the previous case study:
1. Using `tidyr::nest()` to convert a grouped data frame into a nested data
frame where you have list-column of data frames.
1. Using `mutate()` with `purrr::map()` to transform a (e.g.) a list of data
frames into a list of models.
There are two other useful ways to generate list-columns with dplyr:
1. With `mutate()` and vectorised functions that return a list.
1. With `summarise()` and aggregate functions that return an arbitrary
number of results.
These are described below.
### List-columns from vectorised functions
Some useful fuctions take an atomic vector and return a list. For example, earlier you learned about `stringr::str_split()` which takes a character vector and returns a list of charcter vectors.
```{r}
df <- data_frame(x1 = c("a,b,c", "d,e,f,g"))
df %>%
mutate(x2 = stringr::str_split(x1, ","))
```
`unnest()` knows how to handle these lists of vectors:
```{r}
df %>%
mutate(x2 = stringr::str_split(x1, ",")) %>%
unnest()
```
(If you find yourself using this pattern alot, make sure to check out `separate_rows()`)
(If you find yourself using this pattern alot, make sure to check out `tidyr:separate_rows()` which is a wrapper around this common pattern).
### List-columns with multivalued summaries
One restriction of `summarise()` is that it only works with aggregate functions that return a single value. That means that you can't use it with
This can be useful for summary functions like `quantile()` that return a vector of values: