Use new data_grid instead of expand

This commit is contained in:
hadley 2016-07-27 16:42:14 -05:00
parent 40052d1d52
commit 4c789ab8e9
2 changed files with 15 additions and 16 deletions

View File

@ -45,7 +45,7 @@ The goal of a model is not to uncover truth, but to discover a simple approximat
We need a couple of packages specifically designed for modelling, and all the packages you've used before for EDA. We need a couple of packages specifically designed for modelling, and all the packages you've used before for EDA.
```{r setup, message = FALSE} ```{r setup, message = FALSE, cache = FALSE}
# Modelling functions # Modelling functions
library(modelr) library(modelr)
options(na.action = na.warn) options(na.action = na.warn)
@ -53,7 +53,6 @@ options(na.action = na.warn)
# EDA tools # EDA tools
library(ggplot2) library(ggplot2)
library(dplyr) library(dplyr)
library(tidyr)
``` ```
## A simple model ## A simple model
@ -243,10 +242,10 @@ It's also useful to see what the model doesn't capture, the so called residuals
### Predictions ### Predictions
To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies. The easiest way to do that is to use `tidyr::expand()`. Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations: To visualise the predictions from a model, we start by generating an evenly spaced grid of values that covers the region where our data lies. The easiest way to do that is to use `modelr::data_grid()`. Its first argument is a data frame, and for each subsequent argument it finds the unique variables and then generates all combinations:
```{r} ```{r}
grid <- sim1 %>% expand(x) grid <- sim1 %>% data_grid(x)
grid grid
``` ```
@ -377,7 +376,7 @@ We can fit a model to it, and generate predictions:
mod2 <- lm(y ~ x, data = sim2) mod2 <- lm(y ~ x, data = sim2)
grid <- sim2 %>% grid <- sim2 %>%
expand(x) %>% data_grid(x) %>%
add_predictions(mod2) add_predictions(mod2)
grid grid
``` ```
@ -416,7 +415,7 @@ When you add variables with `+`, the model will estimate each effect independent
To visualise these models we need two new tricks: To visualise these models we need two new tricks:
1. We have two predictors, so we need to give `expand()` two variables. 1. We have two predictors, so we need to give `data_grid()` both variables.
It finds all the unique values of `x1` and `x2` and then generates all It finds all the unique values of `x1` and `x2` and then generates all
combinations. combinations.
@ -429,7 +428,7 @@ Together this gives us:
```{r} ```{r}
grid <- sim3 %>% grid <- sim3 %>%
expand(x1, x2) %>% data_grid(x1, x2) %>%
gather_predictions(mod1, mod2) gather_predictions(mod1, mod2)
grid grid
``` ```
@ -467,7 +466,7 @@ mod1 <- lm(y ~ x1 + x2, data = sim4)
mod2 <- lm(y ~ x1 * x2, data = sim4) mod2 <- lm(y ~ x1 * x2, data = sim4)
grid <- sim4 %>% grid <- sim4 %>%
expand( data_grid(
x1 = seq_range(x1, 5), x1 = seq_range(x1, 5),
x2 = seq_range(x2, 5) x2 = seq_range(x2, 5)
) %>% ) %>%
@ -475,7 +474,7 @@ grid <- sim4 %>%
grid grid
``` ```
Note my use of `seq_range()` inside `expand()`. Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers. It's probably not super important here, but it's a useful technique in general. There are two other useful arguments to `seq_range()`: Note my use of `seq_range()` inside `data_grid()`. Instead of using every unique value of `x`, I'm going to use a regularly spaced grid of five values between the minimum and maximum numbers. It's probably not super important here, but it's a useful technique in general. There are two other useful arguments to `seq_range()`:
* `pretty = TRUE` will generate a "pretty" sequence, i.e. something that looks * `pretty = TRUE` will generate a "pretty" sequence, i.e. something that looks
nice to the human eye. This is useful if you want to produce tables of nice to the human eye. This is useful if you want to produce tables of
@ -554,7 +553,7 @@ model_matrix(df, y ~ poly(x, 2))
However there's one major problem with using `poly()`: outside the range of the data, polynomials rapidly shoot off to positive or negative infinity. One safer alternative is to use the natural spline, `splines::ns()`. However there's one major problem with using `poly()`: outside the range of the data, polynomials rapidly shoot off to positive or negative infinity. One safer alternative is to use the natural spline, `splines::ns()`.
```{r} ```{r, cache = FALSE}
library(splines) library(splines)
model_matrix(df, y ~ ns(x, 2)) model_matrix(df, y ~ ns(x, 2))
``` ```
@ -581,7 +580,7 @@ mod4 <- lm(y ~ ns(x, 4), data = sim5)
mod5 <- lm(y ~ ns(x, 5), data = sim5) mod5 <- lm(y ~ ns(x, 5), data = sim5)
grid <- sim5 %>% grid <- sim5 %>%
expand(x = seq_range(x, n = 50, expand = 0.1)) %>% data_grid(x = seq_range(x, n = 50, expand = 0.1)) %>%
gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y") gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")
ggplot(sim5, aes(x, y)) + ggplot(sim5, aes(x, y)) +
@ -610,7 +609,7 @@ sim6 <- tibble(
mod <- lm(y ~ x1 * x2, data = sim6) mod <- lm(y ~ x1 * x2, data = sim6)
grid <- sim6 %>% grid <- sim6 %>%
expand( data_grid(
x1 = seq_range(x1, 10), x1 = seq_range(x1, 10),
x2 = c(0, 0.5, 1, 1.5) x2 = c(0, 0.5, 1, 1.5)
) %>% ) %>%

View File

@ -86,7 +86,7 @@ Then we look at what the model tells us about the data. Note that I back transfo
```{r} ```{r}
grid <- diamonds2 %>% grid <- diamonds2 %>%
expand(carat = seq_range(carat, 20)) %>% data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>% mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice") %>% add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice) mutate(price = 2 ^ lprice)
@ -213,7 +213,7 @@ One way to remove this strong pattern is to use a model. First, we fit the model
mod <- lm(n ~ wday, data = daily) mod <- lm(n ~ wday, data = daily)
grid <- daily %>% grid <- daily %>%
expand(wday) %>% data_grid(wday) %>%
add_predictions(mod, "n") add_predictions(mod, "n")
ggplot(daily, aes(wday, n)) + ggplot(daily, aes(wday, n)) +
@ -340,7 +340,7 @@ We can see the problem by overlaying the predictions from the model on to the ra
```{r} ```{r}
grid <- daily %>% grid <- daily %>%
expand(wday, term) %>% data_grid(wday, term) %>%
add_predictions(mod2, "n") add_predictions(mod2, "n")
ggplot(daily, aes(wday, n)) + ggplot(daily, aes(wday, n)) +
@ -372,7 +372,7 @@ library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily) mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>% daily %>%
tidyr::expand(wday, date = seq_range(date, n = 13)) %>% data_grid(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod) %>% add_predictions(mod) %>%
ggplot(aes(date, pred, colour = wday)) + ggplot(aes(date, pred, colour = wday)) +
geom_line() + geom_line() +