Model building

《区域水环境污染数据分析实践》
Data analysis practice of regional water environment pollution

苏命、王为东
College of Resources and Environment, University of Chinese Academy of Sciences
Research Center for Eco-Environmental Sciences, Chinese Academy of Sciences

2024-04-09

Main steps of tidymodels

What is tidymodels?

library(tidymodels)

Overall workflow

Installing the required packages

# Install the packages for the workshop
pkgs <- 
  c("bonsai", "doParallel", "embed", "finetune", "lightgbm", "lme4",
    "plumber", "probably", "ranger", "rpart", "rpart.plot", "rules",
    "splines2", "stacks", "text2vec", "textrecipes", "tidymodels", 
    "vetiver", "remotes")

install.packages(pkgs)



Data on Chicago taxi trips

library(tidymodels)
taxi
# A tibble: 10,000 × 7
   tip   distance company                      local dow   month  hour
   <fct>    <dbl> <fct>                        <fct> <fct> <fct> <int>
 1 yes      17.2  Chicago Independents         no    Thu   Feb      16
 2 yes       0.88 City Service                 yes   Thu   Mar       8
 3 yes      18.1  other                        no    Mon   Feb      18
 4 yes      20.7  Chicago Independents         no    Mon   Apr       8
 5 yes      12.2  Chicago Independents         no    Sun   Mar      21
 6 yes       0.94 Sun Taxi                     yes   Sat   Apr      23
 7 yes      17.5  Flash Cab                    no    Fri   Mar      12
 8 yes      17.7  other                        no    Sun   Jan       6
 9 yes       1.85 Taxicab Insurance Agency Llc no    Fri   Apr      12
10 yes       1.47 City Service                 no    Tue   Mar      14
# ℹ 9,990 more rows

Data splitting and usage

For machine learning, we typically split the data into a training set and a test set:

  • The training set is used to estimate model parameters.
  • The test set is used for an independent assessment of model performance.

Do not use the test set during training.
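The bookkeeping behind a split can be sketched in a few lines of base R; rsample's `initial_split()` does this for you (plus printing and stratification). The `toy` data frame below is made up purely for illustration:

```r
# A hand-rolled 75/25 train/test split in base R, mirroring what
# initial_split() does with its default prop = 0.75.
set.seed(123)
toy <- data.frame(x = rnorm(100), y = rnorm(100))  # stand-in data

n         <- nrow(toy)
train_idx <- sample(n, size = floor(0.75 * n))  # random row numbers

toy_train <- toy[train_idx, ]   # 75 rows: estimate model parameters here
toy_test  <- toy[-train_idx, ]  # 25 rows: touch only for the final assessment

nrow(toy_train)  # 75
nrow(toy_test)   # 25
```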

The initial split

set.seed(123)
taxi_split <- initial_split(taxi)
taxi_split
<Training/Testing/Total>
<7500/2500/10000>

Accessing the data

taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)

The training set

taxi_train
# A tibble: 7,500 × 7
   tip   distance company                   local dow   month  hour
   <fct>    <dbl> <fct>                     <fct> <fct> <fct> <int>
 1 yes       0.7  Taxi Affiliation Services yes   Tue   Mar      18
 2 yes       0.99 Sun Taxi                  yes   Tue   Jan       8
 3 yes       1.78 other                     no    Sat   Mar      22
 4 yes       0    Taxi Affiliation Services yes   Wed   Apr      15
 5 yes       0    Taxi Affiliation Services no    Sun   Jan      21
 6 yes       2.3  other                     no    Sat   Apr      21
 7 yes       6.35 Sun Taxi                  no    Wed   Mar      16
 8 yes       2.79 other                     no    Sun   Feb      14
 9 yes      16.6  other                     no    Sun   Apr      18
10 yes       0.02 Chicago Independents      yes   Sun   Apr      15
# ℹ 7,490 more rows

Exercise

set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8)
taxi_train <- training(taxi_split)
taxi_test <- testing(taxi_split)

nrow(taxi_train)
[1] 8000
nrow(taxi_test)
[1] 2000

Stratification

Use strata = tip

set.seed(123)
taxi_split <- initial_split(taxi, prop = 0.8, strata = tip)
taxi_split
<Training/Testing/Total>
<8000/2000/10000>

Stratification

Stratification often helps, with very little downside
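What stratification does can be sketched in base R: sample within each outcome class, so the train and test sets keep the original class balance. The 9:1 `yes`/`no` vector below is made up for illustration:

```r
# Stratified 80/20 split by hand: draw 80% from each class separately,
# so both resulting sets preserve the 9:1 class ratio.
set.seed(123)
y <- factor(rep(c("yes", "no"), times = c(900, 100)))  # imbalanced outcome

train_idx <- unlist(lapply(split(seq_along(y), y), function(idx) {
  sample(idx, size = floor(0.8 * length(idx)))  # 80% of each class
}))

table(y[train_idx])   # no: 80, yes: 720 -> still a 9:1 ratio
table(y[-train_idx])  # no: 20, yes: 180
```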

Model types

Models come in many varieties

  • lm for linear model

  • glm for generalized linear model (e.g. logistic regression)

  • glmnet for regularized regression

  • keras for regression using TensorFlow

  • stan for Bayesian regression

  • spark for large data sets

Specifying a model

logistic_reg()
Logistic Regression Model Specification (classification)

Computational engine: glm 

To specify a model

logistic_reg() %>%
  set_engine("glmnet")
Logistic Regression Model Specification (classification)

Computational engine: glmnet 
logistic_reg() %>%
  set_engine("stan")
Logistic Regression Model Specification (classification)

Computational engine: stan 
  • Choose a model
  • Specify an engine
  • Set the mode

To specify a model

decision_tree()
Decision Tree Model Specification (unknown mode)

Computational engine: rpart 

To specify a model

decision_tree() %>% 
  set_mode("classification")
Decision Tree Model Specification (classification)

Computational engine: rpart 



All available models are listed at https://www.tidymodels.org/find/parsnip/

Workflows

Why use a workflow()?

  • Workflows handle new factor levels better than base R tools
  • You can use preprocessors other than a formula (more on feature engineering in advanced tidymodels!)
  • They help organize your work when you use multiple models
  • Most importantly, a workflow covers the whole modeling process: fit() and predict() apply to the preprocessing steps as well as to the model fit itself

A model workflow

tree_spec <-
  decision_tree(cost_complexity = 0.002) %>% 
  set_mode("classification")

tree_spec %>% 
  fit(tip ~ ., data = taxi_train) 
parsnip model object

n= 8000 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 8000 616 yes (0.92300000 0.07700000)  
   2) distance>=14.12 2041  68 yes (0.96668300 0.03331700) *
   3) distance< 14.12 5959 548 yes (0.90803826 0.09196174)  
     6) distance< 5.275 5419 450 yes (0.91695885 0.08304115) *
     7) distance>=5.275 540  98 yes (0.81851852 0.18148148)  
      14) company=Chicago Independents,City Service,Sun Taxi,Taxi Affiliation Services,Taxicab Insurance Agency Llc,other 478  68 yes (0.85774059 0.14225941) *
      15) company=Flash Cab 62  30 yes (0.51612903 0.48387097)  
        30) dow=Thu 12   2 yes (0.83333333 0.16666667) *
        31) dow=Sun,Mon,Tue,Wed,Fri,Sat 50  22 no (0.44000000 0.56000000)  
          62) distance>=11.77 14   4 yes (0.71428571 0.28571429) *
          63) distance< 11.77 36  12 no (0.33333333 0.66666667) *

A model workflow

tree_spec <-
  decision_tree(cost_complexity = 0.002) %>% 
  set_mode("classification")

workflow() %>%
  add_formula(tip ~ .) %>%
  add_model(tree_spec) %>%
  fit(data = taxi_train) 
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Formula
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
tip ~ .

── Model ───────────────────────────────────────────────────────────────────────
n= 8000 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 8000 616 yes (0.92300000 0.07700000)  
   2) distance>=14.12 2041  68 yes (0.96668300 0.03331700) *
   3) distance< 14.12 5959 548 yes (0.90803826 0.09196174)  
     6) distance< 5.275 5419 450 yes (0.91695885 0.08304115) *
     7) distance>=5.275 540  98 yes (0.81851852 0.18148148)  
      14) company=Chicago Independents,City Service,Sun Taxi,Taxi Affiliation Services,Taxicab Insurance Agency Llc,other 478  68 yes (0.85774059 0.14225941) *
      15) company=Flash Cab 62  30 yes (0.51612903 0.48387097)  
        30) dow=Thu 12   2 yes (0.83333333 0.16666667) *
        31) dow=Sun,Mon,Tue,Wed,Fri,Sat 50  22 no (0.44000000 0.56000000)  
          62) distance>=11.77 14   4 yes (0.71428571 0.28571429) *
          63) distance< 11.77 36  12 no (0.33333333 0.66666667) *

A model workflow

tree_spec <-
  decision_tree(cost_complexity = 0.002) %>% 
  set_mode("classification")

workflow(tip ~ ., tree_spec) %>% 
  fit(data = taxi_train) 
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Formula
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
tip ~ .

── Model ───────────────────────────────────────────────────────────────────────
n= 8000 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 8000 616 yes (0.92300000 0.07700000)  
   2) distance>=14.12 2041  68 yes (0.96668300 0.03331700) *
   3) distance< 14.12 5959 548 yes (0.90803826 0.09196174)  
     6) distance< 5.275 5419 450 yes (0.91695885 0.08304115) *
     7) distance>=5.275 540  98 yes (0.81851852 0.18148148)  
      14) company=Chicago Independents,City Service,Sun Taxi,Taxi Affiliation Services,Taxicab Insurance Agency Llc,other 478  68 yes (0.85774059 0.14225941) *
      15) company=Flash Cab 62  30 yes (0.51612903 0.48387097)  
        30) dow=Thu 12   2 yes (0.83333333 0.16666667) *
        31) dow=Sun,Mon,Tue,Wed,Fri,Sat 50  22 no (0.44000000 0.56000000)  
          62) distance>=11.77 14   4 yes (0.71428571 0.28571429) *
          63) distance< 11.77 36  12 no (0.33333333 0.66666667) *

Prediction

How do you use your new tree_fit model?

tree_spec <-
  decision_tree(cost_complexity = 0.002) %>% 
  set_mode("classification")

tree_fit <-
  workflow(tip ~ ., tree_spec) %>% 
  fit(data = taxi_train) 

Exercise

Run:

predict(tree_fit, new_data = taxi_test)

Run:

augment(tree_fit, new_data = taxi_test)

What do you get?

Predictions with tidymodels

  • Predictions always come back in a tibble
  • Column names and types are unsurprising and predictable
  • The number of rows in new_data and in the output are the same

Understanding your model

How can we understand the tree_fit model?

Evaluating models: predicted values

augment(taxi_fit, new_data = taxi_train) %>%
  relocate(tip, .pred_class, .pred_yes, .pred_no)
# A tibble: 8,000 × 10
   tip   .pred_class .pred_yes .pred_no distance company local dow   month  hour
   <fct> <fct>           <dbl>    <dbl>    <dbl> <fct>   <fct> <fct> <fct> <int>
 1 yes   yes             0.967   0.0333    17.2  Chicag… no    Thu   Feb      16
 2 yes   yes             0.935   0.0646     0.88 City S… yes   Thu   Mar       8
 3 yes   yes             0.967   0.0333    18.1  other   no    Mon   Feb      18
 4 yes   yes             0.949   0.0507    12.2  Chicag… no    Sun   Mar      21
 5 yes   yes             0.821   0.179      0.94 Sun Ta… yes   Sat   Apr      23
 6 yes   yes             0.967   0.0333    17.5  Flash … no    Fri   Mar      12
 7 yes   yes             0.967   0.0333    17.7  other   no    Sun   Jan       6
 8 yes   yes             0.938   0.0616     1.85 Taxica… no    Fri   Apr      12
 9 yes   yes             0.938   0.0616     0.53 Sun Ta… no    Tue   Mar      18
10 yes   yes             0.931   0.0694     6.65 Taxica… no    Sun   Apr      11
# ℹ 7,990 more rows

Confusion matrix

augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class)
          Truth
Prediction  yes   no
       yes 7341  536
       no    43   80

Confusion matrix

augment(taxi_fit, new_data = taxi_train) %>%
  conf_mat(truth = tip, estimate = .pred_class) %>%
  autoplot(type = "heatmap")

Metrics for model performance

augment(taxi_fit, new_data = taxi_train) %>%
  accuracy(truth = tip, estimate = .pred_class)
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.928

Evaluating binary classification models

Sensitivity and specificity are two key metrics for assessing a binary classifier:

  • Sensitivity, also called the true positive rate, measures the model's ability to correctly identify positive cases. It is the number of true positives divided by the number of true positives plus false negatives:

\[ \text{Sensitivity} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}} \]

  • Specificity, also called the true negative rate, measures the model's ability to correctly identify negative cases. It is the number of true negatives divided by the number of true negatives plus false positives:

\[ \text{Specificity} = \frac{\text{True Negatives}}{\text{True Negatives} + \text{False Positives}} \]

Ideally both are high: high sensitivity means the model captures the true positive cases, while high specificity means it correctly rules out the negative ones.
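Plugging the counts from the confusion matrix shown earlier (truth in columns, predictions in rows) into these formulas reproduces the yardstick results; a base R check:

```r
# Sensitivity and specificity by hand, using the confusion matrix
# counts from the slides ("yes" is the positive class in yardstick).
tp <- 7341  # predicted yes, truly yes
fn <- 43    # predicted no,  truly yes
fp <- 536   # predicted yes, truly no
tn <- 80    # predicted no,  truly no

sensitivity <- tp / (tp + fn)                   # 0.994: true tips are caught
specificity <- tn / (tn + fp)                   # 0.130: non-tips mostly missed
accuracy    <- (tp + tn) / (tp + fn + fp + tn)  # 0.928, inflated by the "yes" majority

round(c(sensitivity, specificity, accuracy), 3)
```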

Metrics for model performance

augment(taxi_fit, new_data = taxi_train) %>%
  sensitivity(truth = tip, estimate = .pred_class)
# A tibble: 1 × 3
  .metric     .estimator .estimate
  <chr>       <chr>          <dbl>
1 sensitivity binary         0.994

augment(taxi_fit, new_data = taxi_train) %>%
  specificity(truth = tip, estimate = .pred_class)
# A tibble: 1 × 3
  .metric     .estimator .estimate
  <chr>       <chr>          <dbl>
1 specificity binary         0.130

Metrics for model performance

We can use metric_set() to combine multiple calculations into one

taxi_metrics <- metric_set(accuracy, specificity, sensitivity)

augment(taxi_fit, new_data = taxi_train) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
# A tibble: 3 × 3
  .metric     .estimator .estimate
  <chr>       <chr>          <dbl>
1 accuracy    binary         0.928
2 specificity binary         0.130
3 sensitivity binary         0.994
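The idea behind metric_set() can be sketched as a small function factory in base R: close over a named list of metric functions and apply each one to the same truth/estimate pair. The metric functions below are simplified stand-ins, not the yardstick implementations:

```r
# A metric_set()-like bundle: one callable that evaluates several
# metrics and returns them in long format.
metric_bundle <- function(...) {
  metrics <- list(...)
  function(truth, estimate) {
    data.frame(
      .metric   = names(metrics),
      .estimate = vapply(metrics, function(f) f(truth, estimate), numeric(1))
    )
  }
}

# Simplified two-class metrics with "yes" as the positive class
accuracy_fn    <- function(t, e) mean(t == e)
sensitivity_fn <- function(t, e) mean(e[t == "yes"] == "yes")
specificity_fn <- function(t, e) mean(e[t == "no"] == "no")

my_metrics <- metric_bundle(accuracy    = accuracy_fn,
                            sensitivity = sensitivity_fn,
                            specificity = specificity_fn)

truth    <- c("yes", "yes", "yes", "no", "no")
estimate <- c("yes", "yes", "no",  "no", "yes")
my_metrics(truth, estimate)  # accuracy 0.6, sensitivity 2/3, specificity 0.5
```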

Metrics for model performance

taxi_metrics <- metric_set(accuracy, specificity, sensitivity)

augment(taxi_fit, new_data = taxi_train) %>%
  group_by(local) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
# A tibble: 6 × 4
  local .metric     .estimator .estimate
  <fct> <chr>       <chr>          <dbl>
1 yes   accuracy    binary         0.898
2 no    accuracy    binary         0.935
3 yes   specificity binary         0.169
4 no    specificity binary         0.116
5 yes   sensitivity binary         0.987
6 no    sensitivity binary         0.996

Varying the threshold

ROC curves

  • The ROC (Receiver Operating Characteristic) curve evaluates a binary classifier by comparing its sensitivity and specificity across different thresholds.
  • Its x-axis is the false positive rate (FPR) and its y-axis is the true positive rate (TPR). Each point on the curve corresponds to a particular threshold, so varying the threshold shows how the model behaves under different conditions.
  • The closer the curve comes to the top-left corner (0, 1), the better the model: it achieves a high true positive rate at a low false positive rate. The area under the ROC curve (AUC) summarizes performance in a single number; a larger AUC indicates a better model.
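The AUC also has a handy interpretation: it is the probability that a randomly chosen positive case receives a higher predicted score than a randomly chosen negative case (ties counting one half). A base R sketch with made-up scores:

```r
# AUC as a pairwise ranking probability: compare every (positive,
# negative) pair of predicted scores; ties contribute 1/2.
pos <- c(0.9, 0.8, 0.4)  # .pred_yes for cases that are truly "yes"
neg <- c(0.7, 0.3, 0.2)  # .pred_yes for cases that are truly "no"

pairs <- expand.grid(p = pos, n = neg)  # all 9 combinations
auc   <- mean((pairs$p > pairs$n) + 0.5 * (pairs$p == pairs$n))
auc  # 8 of 9 pairs ranked correctly: 0.889
```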

ROC curve plot

augment(taxi_fit, new_data = taxi_train) %>% 
  roc_curve(truth = tip, .pred_yes) %>%
  autoplot()

Overfitting

Cross-validation

vfold_cv(taxi_train) # v = 10 is default
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [7200/800]> Fold01
 2 <split [7200/800]> Fold02
 3 <split [7200/800]> Fold03
 4 <split [7200/800]> Fold04
 5 <split [7200/800]> Fold05
 6 <split [7200/800]> Fold06
 7 <split [7200/800]> Fold07
 8 <split [7200/800]> Fold08
 9 <split [7200/800]> Fold09
10 <split [7200/800]> Fold10

Cross-validation

What is in this?

taxi_folds <- vfold_cv(taxi_train)
taxi_folds$splits[1:3]
[[1]]
<Analysis/Assess/Total>
<7200/800/8000>

[[2]]
<Analysis/Assess/Total>
<7200/800/8000>

[[3]]
<Analysis/Assess/Total>
<7200/800/8000>
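The fold bookkeeping can be mimicked in base R: every training row gets exactly one fold label, fold k's 800 rows form its assessment set, and the other 7,200 its analysis set:

```r
# 10-fold assignment by hand for the 8,000 training rows: shuffle
# fold labels, then each fold holds out its own tenth of the data.
set.seed(123)
n     <- 8000
folds <- sample(rep(1:10, length.out = n))  # one label per row

assessment_1 <- which(folds == 1)  # 800 rows held out in fold 1
analysis_1   <- which(folds != 1)  # 7,200 rows used for fitting

c(length(analysis_1), length(assessment_1))  # 7200  800
```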

Cross-validation

vfold_cv(taxi_train, v = 5)
#  5-fold cross-validation 
# A tibble: 5 × 2
  splits              id   
  <list>              <chr>
1 <split [6400/1600]> Fold1
2 <split [6400/1600]> Fold2
3 <split [6400/1600]> Fold3
4 <split [6400/1600]> Fold4
5 <split [6400/1600]> Fold5

Cross-validation

vfold_cv(taxi_train, strata = tip)
#  10-fold cross-validation using stratification 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [7200/800]> Fold01
 2 <split [7200/800]> Fold02
 3 <split [7200/800]> Fold03
 4 <split [7200/800]> Fold04
 5 <split [7200/800]> Fold05
 6 <split [7200/800]> Fold06
 7 <split [7200/800]> Fold07
 8 <split [7200/800]> Fold08
 9 <split [7200/800]> Fold09
10 <split [7200/800]> Fold10

Stratification often helps, with very little downside

Cross-validation

We’ll use this setup:

set.seed(123)
taxi_folds <- vfold_cv(taxi_train, v = 10, strata = tip)
taxi_folds
#  10-fold cross-validation using stratification 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [7200/800]> Fold01
 2 <split [7200/800]> Fold02
 3 <split [7200/800]> Fold03
 4 <split [7200/800]> Fold04
 5 <split [7200/800]> Fold05
 6 <split [7200/800]> Fold06
 7 <split [7200/800]> Fold07
 8 <split [7200/800]> Fold08
 9 <split [7200/800]> Fold09
10 <split [7200/800]> Fold10

Set the seed when creating resamples

Fit our model to the resamples

taxi_wflow <- workflow(tip ~ ., tree_spec)  # the decision tree workflow from above
taxi_res <- fit_resamples(taxi_wflow, taxi_folds)
taxi_res
# Resampling results
# 10-fold cross-validation using stratification 
# A tibble: 10 × 4
   splits             id     .metrics         .notes          
   <list>             <chr>  <list>           <list>          
 1 <split [7200/800]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]>
 2 <split [7200/800]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]>
 3 <split [7200/800]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]>
 4 <split [7200/800]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]>
 5 <split [7200/800]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]>
 6 <split [7200/800]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]>
 7 <split [7200/800]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]>
 8 <split [7200/800]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]>
 9 <split [7200/800]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]>
10 <split [7200/800]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]>

Evaluating model performance

taxi_res %>%
  collect_metrics()
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.915    10 0.00309 Preprocessor1_Model1
2 roc_auc  binary     0.624    10 0.0105  Preprocessor1_Model1

We can reliably measure performance using only the training data 🎉

Comparing metrics

How do the metrics from resampling compare to the metrics from training and testing?

taxi_res %>%
  collect_metrics() %>% 
  select(.metric, mean, n)
# A tibble: 2 × 3
  .metric   mean     n
  <chr>    <dbl> <int>
1 accuracy 0.915    10
2 roc_auc  0.624    10

The ROC AUC previously was

  • 0.69 for the training set
  • 0.64 for the test set

Remember that:

⚠️ the training set gives you overly optimistic metrics

⚠️ the test set is precious

Evaluating model performance

# Save the assessment set results
ctrl_taxi <- control_resamples(save_pred = TRUE)
taxi_res <- fit_resamples(taxi_wflow, taxi_folds, control = ctrl_taxi)

taxi_res
# Resampling results
# 10-fold cross-validation using stratification 
# A tibble: 10 × 5
   splits             id     .metrics         .notes           .predictions
   <list>             <chr>  <list>           <list>           <list>      
 1 <split [7200/800]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 2 <split [7200/800]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 3 <split [7200/800]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 4 <split [7200/800]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 5 <split [7200/800]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 6 <split [7200/800]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 7 <split [7200/800]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 8 <split [7200/800]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 9 <split [7200/800]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
10 <split [7200/800]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    

Evaluating model performance

# Collect the predictions from each assessment set
taxi_preds <- collect_predictions(taxi_res)
taxi_preds
# A tibble: 8,000 × 7
   id     .pred_yes .pred_no  .row .pred_class tip   .config             
   <chr>      <dbl>    <dbl> <int> <fct>       <fct> <chr>               
 1 Fold01     0.938   0.0615    14 yes         yes   Preprocessor1_Model1
 2 Fold01     0.946   0.0544    19 yes         yes   Preprocessor1_Model1
 3 Fold01     0.973   0.0269    33 yes         yes   Preprocessor1_Model1
 4 Fold01     0.903   0.0971    43 yes         yes   Preprocessor1_Model1
 5 Fold01     0.973   0.0269    74 yes         yes   Preprocessor1_Model1
 6 Fold01     0.903   0.0971   103 yes         yes   Preprocessor1_Model1
 7 Fold01     0.915   0.0851   104 yes         no    Preprocessor1_Model1
 8 Fold01     0.903   0.0971   124 yes         yes   Preprocessor1_Model1
 9 Fold01     0.667   0.333    126 yes         yes   Preprocessor1_Model1
10 Fold01     0.949   0.0510   128 yes         yes   Preprocessor1_Model1
# ℹ 7,990 more rows

Evaluating model performance

taxi_preds %>% 
  group_by(id) %>%
  taxi_metrics(truth = tip, estimate = .pred_class)
# A tibble: 30 × 4
   id     .metric  .estimator .estimate
   <chr>  <chr>    <chr>          <dbl>
 1 Fold01 accuracy binary         0.905
 2 Fold02 accuracy binary         0.925
 3 Fold03 accuracy binary         0.926
 4 Fold04 accuracy binary         0.915
 5 Fold05 accuracy binary         0.902
 6 Fold06 accuracy binary         0.912
 7 Fold07 accuracy binary         0.906
 8 Fold08 accuracy binary         0.91 
 9 Fold09 accuracy binary         0.918
10 Fold10 accuracy binary         0.931
# ℹ 20 more rows

Where are the fitted models?

taxi_res
# Resampling results
# 10-fold cross-validation using stratification 
# A tibble: 10 × 5
   splits             id     .metrics         .notes           .predictions
   <list>             <chr>  <list>           <list>           <list>      
 1 <split [7200/800]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 2 <split [7200/800]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 3 <split [7200/800]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 4 <split [7200/800]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 5 <split [7200/800]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 6 <split [7200/800]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 7 <split [7200/800]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 8 <split [7200/800]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 9 <split [7200/800]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
10 <split [7200/800]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    

Bootstrapping

set.seed(3214)
bootstraps(taxi_train)
# Bootstrap sampling 
# A tibble: 25 × 2
   splits              id         
   <list>              <chr>      
 1 <split [8000/2902]> Bootstrap01
 2 <split [8000/2916]> Bootstrap02
 3 <split [8000/3004]> Bootstrap03
 4 <split [8000/2979]> Bootstrap04
 5 <split [8000/2961]> Bootstrap05
 6 <split [8000/2962]> Bootstrap06
 7 <split [8000/3026]> Bootstrap07
 8 <split [8000/2926]> Bootstrap08
 9 <split [8000/2972]> Bootstrap09
10 <split [8000/2972]> Bootstrap10
# ℹ 15 more rows
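The assessment-set sizes above hover around 2,900 to 3,000 because a bootstrap resample draws n rows with replacement, leaving roughly n/e (about 37%) of rows out of bag. A base R sketch:

```r
# One bootstrap resample by hand: the analysis set is n draws with
# replacement; the never-drawn rows are the out-of-bag assessment set.
set.seed(3214)
n      <- 8000
in_bag <- sample(n, size = n, replace = TRUE)  # analysis set (has duplicates)
oob    <- setdiff(seq_len(n), in_bag)          # assessment set

length(in_bag)  # always 8000
length(oob)     # close to 8000 * exp(-1), i.e. about 2943
```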

Monte Carlo Cross-Validation

set.seed(322)
mc_cv(taxi_train, times = 10)
# Monte Carlo cross-validation (0.75/0.25) with 10 resamples  
# A tibble: 10 × 2
   splits              id        
   <list>              <chr>     
 1 <split [6000/2000]> Resample01
 2 <split [6000/2000]> Resample02
 3 <split [6000/2000]> Resample03
 4 <split [6000/2000]> Resample04
 5 <split [6000/2000]> Resample05
 6 <split [6000/2000]> Resample06
 7 <split [6000/2000]> Resample07
 8 <split [6000/2000]> Resample08
 9 <split [6000/2000]> Resample09
10 <split [6000/2000]> Resample10

Validation set

set.seed(853)
taxi_val_split <- initial_validation_split(taxi, strata = tip)
validation_set(taxi_val_split)
# A tibble: 1 × 2
  splits              id        
  <list>              <chr>     
1 <split [6000/2000]> validation

Create a random forest model

rf_spec <- rand_forest(trees = 1000, mode = "classification")
rf_spec
Random Forest Model Specification (classification)

Main Arguments:
  trees = 1000

Computational engine: ranger 

Create a random forest model

rf_wflow <- workflow(tip ~ ., rf_spec)
rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
tip ~ .

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  trees = 1000

Computational engine: ranger 

Evaluating model performance

ctrl_taxi <- control_resamples(save_pred = TRUE)

# Random forest uses random numbers so set the seed first

set.seed(2)
rf_res <- fit_resamples(rf_wflow, taxi_folds, control = ctrl_taxi)
collect_metrics(rf_res)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.923    10 0.00317 Preprocessor1_Model1
2 roc_auc  binary     0.616    10 0.0147  Preprocessor1_Model1

The whole game - status update

The final fit

# taxi_split has train + test info
final_fit <- last_fit(rf_wflow, taxi_split) 

final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits              id               .metrics .notes   .predictions .workflow 
  <list>              <chr>            <list>   <list>   <list>       <list>    
1 <split [8000/2000]> train/test split <tibble> <tibble> <tibble>     <workflow>

What is in final_fit?

collect_metrics(final_fit)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.914 Preprocessor1_Model1
2 roc_auc  binary         0.638 Preprocessor1_Model1

These are metrics computed with the test set

What is in final_fit?

collect_predictions(final_fit)
# A tibble: 2,000 × 7
   id               .pred_yes .pred_no  .row .pred_class tip   .config          
   <chr>                <dbl>    <dbl> <int> <fct>       <fct> <chr>            
 1 train/test split     0.957   0.0426     4 yes         yes   Preprocessor1_Mo…
 2 train/test split     0.938   0.0621    10 yes         yes   Preprocessor1_Mo…
 3 train/test split     0.958   0.0416    19 yes         yes   Preprocessor1_Mo…
 4 train/test split     0.894   0.106     23 yes         yes   Preprocessor1_Mo…
 5 train/test split     0.943   0.0573    28 yes         yes   Preprocessor1_Mo…
 6 train/test split     0.979   0.0213    34 yes         yes   Preprocessor1_Mo…
 7 train/test split     0.954   0.0463    35 yes         yes   Preprocessor1_Mo…
 8 train/test split     0.928   0.0722    38 yes         yes   Preprocessor1_Mo…
 9 train/test split     0.985   0.0147    40 yes         yes   Preprocessor1_Mo…
10 train/test split     0.948   0.0523    42 yes         no    Preprocessor1_Mo…
# ℹ 1,990 more rows

What is in final_fit?

extract_workflow(final_fit)
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
tip ~ .

── Model ───────────────────────────────────────────────────────────────────────
Ranger result

Call:
 ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~1000,      num.threads = 1, verbose = FALSE, seed = sample.int(10^5,          1), probability = TRUE) 

Type:                             Probability estimation 
Number of trees:                  1000 
Sample size:                      8000 
Number of independent variables:  6 
Mtry:                             2 
Target node size:                 10 
Variable importance mode:         none 
Splitrule:                        gini 
OOB prediction error (Brier s.):  0.07069778 

Use this fitted workflow for prediction on new data, such as when deploying the model

Tuning models - Specifying tuning parameters

rf_spec <- rand_forest(min_n = tune()) %>% 
  set_mode("classification")

rf_wflow <- workflow(tip ~ ., rf_spec)
rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
tip ~ .

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  min_n = tune()

Computational engine: ranger 

Try out multiple values

tune_grid() works similarly to fit_resamples() but covers multiple parameter values:

set.seed(22)
rf_res <- tune_grid(
  rf_wflow,
  taxi_folds,
  grid = 5
)

Compare results

Inspecting results and selecting the best-performing hyperparameter(s):

show_best(rf_res)
# A tibble: 5 × 7
  min_n .metric .estimator  mean     n std_err .config             
  <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1    33 roc_auc binary     0.623    10  0.0149 Preprocessor1_Model1
2    31 roc_auc binary     0.622    10  0.0154 Preprocessor1_Model3
3    21 roc_auc binary     0.620    10  0.0149 Preprocessor1_Model4
4    13 roc_auc binary     0.617    10  0.0137 Preprocessor1_Model5
5     6 roc_auc binary     0.611    10  0.0156 Preprocessor1_Model2
best_parameter <- select_best(rf_res)
best_parameter
# A tibble: 1 × 2
  min_n .config             
  <int> <chr>               
1    33 Preprocessor1_Model1

collect_metrics() and autoplot() are also available.

The final fit

rf_wflow <- finalize_workflow(rf_wflow, best_parameter)

final_fit <- last_fit(rf_wflow, taxi_split) 

collect_metrics(final_fit)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.913 Preprocessor1_Model1
2 roc_auc  binary         0.648 Preprocessor1_Model1

Hands-on practice

Data

library(tidyverse)
sitedf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-01/nla2007_sampledlakeinformation_20091113.csv") |>
  select(SITE_ID,
    lon = LON_DD, 
    lat = LAT_DD, 
    name = LAKENAME, 
    area = LAKEAREA, 
    zmax = DEPTHMAX
    ) |>
  group_by(SITE_ID) |>
  summarize(lon = mean(lon, na.rm = TRUE),
    lat = mean(lat, na.rm = TRUE),
    name = unique(name),
    area = mean(area, na.rm = TRUE),
    zmax = mean(zmax, na.rm = TRUE))


visitdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2013-09/nla2007_profile_20091008.csv") |>
  select(SITE_ID,
    date = DATE_PROFILE,
    year = YEAR,
    visit = VISIT_NO
  ) |>
  distinct()



waterchemdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2013-09/nla2007_profile_20091008.csv") |>
  select(SITE_ID,
    date = DATE_PROFILE,
    depth = DEPTH,
    temp = TEMP_FIELD,
    do = DO_FIELD,
    ph = PH_FIELD,
    cond = COND_FIELD
  )

sddf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_secchi_20091008.csv") |>
  select(SITE_ID, 
    date = DATE_SECCHI, 
    sd = SECMEAN,
    clear_to_bottom = CLEAR_TO_BOTTOM
  )

trophicdf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_trophic_conditionestimate_20091123.csv") |>
  select(SITE_ID,
    visit = VISIT_NO,
    tp = PTL,
    tn = NTL,
    chla = CHLA) |>
  left_join(visitdf, by = c("SITE_ID", "visit")) |>
  select(-year, -visit) |>
  group_by(SITE_ID, date) |>
  summarize(tp = mean(tp, na.rm = TRUE),
    tn = mean(tn, na.rm = TRUE),
    chla = mean(chla, na.rm = TRUE)
  )



phytodf <- readr::read_csv("https://www.epa.gov/sites/default/files/2014-10/nla2007_phytoplankton_softalgaecount_20091023.csv") |>
  select(SITE_ID,
    date = DATEPHYT,
    depth = SAMPLE_DEPTH,
    phyta = DIVISION,
    genus = GENUS,
    species = SPECIES,
    tax = TAXANAME,
    abund = ABUND) |>
  mutate(phyta = gsub(" .*$", "", phyta)) |>
  filter(!is.na(genus)) |>
  group_by(SITE_ID, date, depth, phyta, genus) |>
  summarize(abund = sum(abund, na.rm = TRUE)) |>
  nest(phytodf = -c(SITE_ID, date))

envdf <- waterchemdf |>
  filter(depth < 2) |>
  select(-depth) |>
  group_by(SITE_ID, date) |>
  summarise_all(~mean(., na.rm = TRUE)) |>
  ungroup() |>
  left_join(sddf, by = c("SITE_ID", "date")) |>
  left_join(trophicdf, by = c("SITE_ID", "date"))

nla <- envdf |>
  left_join(phytodf) |>
  left_join(sitedf, by = "SITE_ID") |>
  filter(!purrr::map_lgl(phytodf, is.null)) |>
  mutate(cyanophyta = purrr::map(phytodf, ~ .x |>
    dplyr::filter(phyta == "Cyanophyta") |>
    summarize(cyanophyta = sum(abund, na.rm = TRUE))
  )) |>
  unnest(cyanophyta) |>
  select(-phyta) |>
  mutate(clear_to_bottom = ifelse(is.na(clear_to_bottom), TRUE, FALSE))



Data

skimr::skim(nla)
Data summary
Name nla
Number of rows 1208
Number of columns 19
_______________________
Column type frequency:
character 3
list 1
logical 1
numeric 14
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
SITE_ID 0 1.00 12 24 0 1114 0
date 0 1.00 8 10 0 116 0
name 44 0.96 5 48 0 990 0

Variable type: list

skim_variable n_missing complete_rate n_unique min_length max_length
phytodf 0 1 1207 4 4

Variable type: logical

skim_variable n_missing complete_rate mean count
clear_to_bottom 0 1 0.96 TRU: 1154, FAL: 54

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
temp 1 1.00 24.13 4.06 10.95 21.43 24.53 27.20 37.73 ▁▅▇▅▁
do 54 0.96 7.84 1.97 0.77 6.85 7.88 8.70 21.00 ▁▇▂▁▁
ph 3 1.00 8.08 0.88 4.10 7.55 8.25 8.64 10.33 ▁▁▅▇▂
cond 247 0.80 714.37 2499.57 3.00 86.93 219.40 471.00 42487.75 ▇▁▁▁▁
sd 62 0.95 2.15 2.49 0.04 0.60 1.35 2.75 36.71 ▇▁▁▁▁
tp 0 1.00 112.77 301.34 1.00 10.00 25.00 90.00 4865.00 ▇▁▁▁▁
tn 0 1.00 1174.39 2061.71 5.00 317.75 584.00 1174.25 26100.00 ▇▁▁▁▁
chla 5 1.00 29.91 69.27 0.07 2.96 7.79 26.08 936.00 ▇▁▁▁▁
lon 0 1.00 -94.34 14.08 -124.25 -103.12 -94.48 -84.32 -67.70 ▃▃▇▆▃
lat 0 1.00 40.55 5.02 26.94 37.12 41.31 44.64 48.98 ▁▃▆▇▆
area 0 1.00 12.01 78.53 0.04 0.24 0.70 2.87 1674.90 ▇▁▁▁▁
zmax 0 1.00 9.41 10.13 0.50 2.90 5.90 12.00 97.00 ▇▁▁▁▁
depth 9 0.99 1.58 0.59 0.08 1.13 2.00 2.00 2.00 ▁▂▂▁▇
cyanophyta 0 1.00 38382.63 191373.91 0.66 1200.61 5483.81 23504.81 4982222.22 ▇▁▁▁▁

A simple model

nla |>
  filter(tp > 1) |>
  ggplot(aes(tn, tp)) +
  geom_point() +
  geom_smooth(method = "lm") +
  scale_x_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x))) +
  scale_y_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x)))
m1 <- lm(log10(tp) ~ log10(tn), data = nla)

summary(m1)

Call:
lm(formula = log10(tp) ~ log10(tn), data = nla)

Residuals:
    Min      1Q  Median      3Q     Max 
-1.8063 -0.2360  0.0125  0.2245  1.9140 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept) -1.92315    0.07166  -26.84   <2e-16 ***
log10(tn)    1.21700    0.02528   48.13   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 0.405 on 1206 degrees of freedom
Multiple R-squared:  0.6576,    Adjusted R-squared:  0.6574 
F-statistic:  2317 on 1 and 1206 DF,  p-value: < 2.2e-16
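On the log-log scale, the fitted slope of about 1.22 implies a power-law relationship between TN and TP. As a minimal sketch (coefficients copied from `summary(m1)` above and hard-coded so the snippet is self-contained; `predict_tp` is a name introduced here for illustration), back-transforming gives tp ≈ 10^-1.92 × tn^1.22:

```r
# Back-transform the log-log fit into a power law.
# Coefficients are hard-coded from summary(m1) above for illustration.
intercept <- -1.92315
slope     <- 1.21700

predict_tp <- function(tn) 10^(intercept + slope * log10(tn))

# At the median TN (~584 ug/L), the predicted TP is close to the median TP (~25):
round(predict_tp(584), 1)
# → 27.8
```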

A more complex indicator

nla |>
  filter(tp > 1) |>
  ggplot(aes(tp, cyanophyta)) +
  geom_point() +
  geom_smooth(method = "lm") +
  scale_x_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x))) +
  scale_y_log10(breaks = scales::trans_breaks("log10", function(x) 10^x),
                labels = scales::trans_format("log10", scales::math_format(10^.x)))
m2 <- lm(log10(cyanophyta) ~ log10(tp), data = nla)

summary(m2)

Call:
lm(formula = log10(cyanophyta) ~ log10(tp), data = nla)

Residuals:
    Min      1Q  Median      3Q     Max 
-5.1551 -0.5128  0.1407  0.6546  3.1811 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  2.82739    0.06181   45.74   <2e-16 ***
log10(tp)    0.58577    0.03784   15.48   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 0.9095 on 1206 degrees of freedom
Multiple R-squared:  0.1658,    Adjusted R-squared:  0.1651 
F-statistic: 239.7 on 1 and 1206 DF,  p-value: < 2.2e-16

tidymodels - Data split

(nla_split <- rsample::initial_split(nla, prop = 0.7, strata = zmax))
<Training/Testing/Total>
<844/364/1208>
(nla_train <- training(nla_split))
# A tibble: 844 × 19
   SITE_ID     date   temp     do    ph   cond    sd clear_to_bottom    tp    tn
   <chr>       <chr> <dbl>  <dbl> <dbl>  <dbl> <dbl> <lgl>           <dbl> <dbl>
 1 NLA06608-0… 6/14…  22.2 NaN     5.52   62.1  0.55 TRUE               36   695
 2 NLA06608-0… 8/29…  30.1   7.5   8.35 1128.   0.71 TRUE               43   738
 3 NLA06608-0… 9/6/…  30.0   9     8.4  1220.   0.49 TRUE               50   843
 4 NLA06608-0… 9/14…  22.7 NaN     5.8    44.5  1.05 TRUE               20   264
 5 NLA06608-0… 9/4/…  25.1 NaN     5.26   45.5  0.65 TRUE               28   384
 6 NLA06608-0… 8/23…  18.1   8.4   9.4  9052.   0.63 TRUE              175  4456
 7 NLA06608-0… 8/6/…  21.6   7.6   9.4  9080    0.85 TRUE              175  4147
 8 NLA06608-0… 6/11…  22.9   4.75  8.38 3373    0.35 TRUE              801  7047
 9 NLA06608-0… 8/7/…  21.2   4.71  9.03 3125.   0.55 TRUE             1376  6578
10 NLA06608-0… 8/9/…  29.4   7.8   8.2   168    0.95 TRUE               12   349
# ℹ 834 more rows
# ℹ 9 more variables: chla <dbl>, phytodf <list>, lon <dbl>, lat <dbl>,
#   name <chr>, area <dbl>, zmax <dbl>, depth <dbl>, cyanophyta <dbl>
(nla_test <- testing(nla_split))
# A tibble: 364 × 19
   SITE_ID      date   temp     do    ph  cond    sd clear_to_bottom    tp    tn
   <chr>        <chr> <dbl>  <dbl> <dbl> <dbl> <dbl> <lgl>           <dbl> <dbl>
 1 NLA06608-00… 7/31…  16.3   8.25  8.15 152.   6.4  TRUE                6   151
 2 NLA06608-00… 7/23…  24.8 NaN     5.07  46.0  0.45 TRUE               22   469
 3 NLA06608-00… 7/17…  25.3   8.56  8.15  77    3.21 TRUE                7   184
 4 NLA06608-00… 8/30…  24.5   8.68  7.84  83    4.1  TRUE                4   223
 5 NLA06608-00… 6/13…  26.8   5.4   7.43 196.   0.31 TRUE              159  1026
 6 NLA06608-00… 9/18…  24.1   6.87  7.78 221.   0.27 TRUE              142  1052
 7 NLA06608-00… 7/10…  24.4   8     7.9  NaN    0.37 TRUE              109   470
 8 NLA06608-00… 7/2/…  27     7.45  8.3  215    1.07 TRUE               20   466
 9 NLA06608-00… 7/11…  27.8   7.1   7.15 176.   0.9  TRUE               35   860
10 NLA06608-00… 8/14…  32.8   8.5   8.8  136.   1.08 TRUE               29   943
# ℹ 354 more rows
# ℹ 9 more variables: chla <dbl>, phytodf <list>, lon <dbl>, lat <dbl>,
#   name <chr>, area <dbl>, zmax <dbl>, depth <dbl>, cyanophyta <dbl>

tidymodels - recipe

nla_formula <- as.formula("cyanophyta ~ temp + do + ph + cond + sd + tp + tn + chla + clear_to_bottom")
# nla_formula <- as.formula("cyanophyta ~ temp + do + ph + cond + sd + tp + tn")
nla_recipe <- recipes::recipe(nla_formula, data = nla_train) |>
  recipes::step_string2factor(all_nominal()) |>
  recipes::step_nzv(all_nominal()) |>
  recipes::step_log(chla, cyanophyta, base = 10) |>
  recipes::step_normalize(all_numeric_predictors()) |>
  prep()
nla_recipe

tidymodels - cross validation

nla_cv <- recipes::bake(
    nla_recipe, 
    new_data = training(nla_split)
  ) |>
  rsample::vfold_cv(v = 10)
nla_cv
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits           id    
   <list>           <chr> 
 1 <split [759/85]> Fold01
 2 <split [759/85]> Fold02
 3 <split [759/85]> Fold03
 4 <split [759/85]> Fold04
 5 <split [760/84]> Fold05
 6 <split [760/84]> Fold06
 7 <split [760/84]> Fold07
 8 <split [760/84]> Fold08
 9 <split [760/84]> Fold09
10 <split [760/84]> Fold10
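Note that because the recipe is prepped and baked on the full training set before `vfold_cv()`, the normalization statistics are shared across folds, which can leak information from assessment sets into the analysis sets. A hedged alternative sketch (assuming `nla_formula`, `nla_train`, and `xgboost_model` from the surrounding slides; `nla_recipe_raw` and `xgboost_wf_rec` are hypothetical names): keep the recipe unprepped, resample the raw training data, and let the workflow re-estimate the preprocessing inside each fold.

```r
# Sketch only: an unprepped recipe inside the workflow, so tune re-estimates
# the log/normalize steps within each resample (no cross-fold leakage).
# Assumes nla_formula, nla_train, and xgboost_model from the slides above.
nla_recipe_raw <- recipes::recipe(nla_formula, data = nla_train) |>
  recipes::step_log(chla, cyanophyta, base = 10) |>
  recipes::step_normalize(recipes::all_numeric_predictors())

nla_cv_raw <- rsample::vfold_cv(nla_train, v = 10)

xgboost_wf_rec <- workflows::workflow() |>
  workflows::add_recipe(nla_recipe_raw) |>
  workflows::add_model(xgboost_model)
```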

tidymodels - Model specification

xgboost_model <- parsnip::boost_tree(
  mode = "regression",
  trees = 1000,
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune()
) |>
  set_engine("xgboost", objective = "reg:squarederror")
xgboost_model
Boosted Tree Model Specification (regression)

Main Arguments:
  trees = 1000
  min_n = tune()
  tree_depth = tune()
  learn_rate = tune()
  loss_reduction = tune()

Engine-Specific Arguments:
  objective = reg:squarederror

Computational engine: xgboost 

tidymodels - Grid specification

# grid specification
xgboost_params <- dials::parameters(
  min_n(),
  tree_depth(),
  learn_rate(),
  loss_reduction()
)
xgboost_params
Collection of 4 parameters for tuning

     identifier           type    object
          min_n          min_n nparam[+]
     tree_depth     tree_depth nparam[+]
     learn_rate     learn_rate nparam[+]
 loss_reduction loss_reduction nparam[+]

tidymodels - Grid specification

xgboost_grid <- dials::grid_max_entropy(
  xgboost_params, 
  size = 60
)
knitr::kable(head(xgboost_grid))
min_n tree_depth learn_rate loss_reduction
22 9 0.0000000 0.0000024
27 13 0.0000721 0.0000000
28 4 0.0002446 0.0000000
32 4 0.0000000 0.0000000
10 11 0.0000000 0.1677615
7 15 0.0000002 0.0000000

tidymodels - Workflow

xgboost_wf <- workflows::workflow() |>
  add_model(xgboost_model) |>
  add_formula(nla_formula)
xgboost_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
cyanophyta ~ temp + do + ph + cond + sd + tp + tn + chla + clear_to_bottom

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (regression)

Main Arguments:
  trees = 1000
  min_n = tune()
  tree_depth = tune()
  learn_rate = tune()
  loss_reduction = tune()

Engine-Specific Arguments:
  objective = reg:squarederror

Computational engine: xgboost 

tidymodels - Tune

# hyperparameter tuning
if (FALSE) {
  xgboost_tuned <- tune::tune_grid(
    object = xgboost_wf,
    resamples = nla_cv,
    grid = xgboost_grid,
    metrics = yardstick::metric_set(rmse, rsq, mae),
    control = tune::control_grid(verbose = TRUE)
  )
  saveRDS(xgboost_tuned, "./xgboost_tuned.RDS")
}
xgboost_tuned <- readRDS("./xgboost_tuned.RDS")

tidymodels - Best model

xgboost_tuned |>
  tune::show_best(metric = "rmse") |>
  knitr::kable()
min_n tree_depth learn_rate loss_reduction .metric .estimator mean n std_err .config
17 3 0.0767369 6.9949310 rmse standard 1.824333 10 0.0461643 Preprocessor1_Model21
19 8 0.0361629 24.8793941 rmse standard 1.825676 10 0.0468745 Preprocessor1_Model16
10 1 0.0093868 0.4234282 rmse standard 1.827963 10 0.0478933 Preprocessor1_Model29
35 10 0.0103573 30.7044800 rmse standard 1.836934 10 0.0492418 Preprocessor1_Model26
12 1 0.0370675 0.0000002 rmse standard 1.840259 10 0.0447300 Preprocessor1_Model04

tidymodels - Best model

xgboost_tuned |>
  collect_metrics()
# A tibble: 180 × 10
   min_n tree_depth learn_rate loss_reduction .metric .estimator  mean     n
   <int>      <int>      <dbl>          <dbl> <chr>   <chr>      <dbl> <int>
 1     3          2  0.0000362   0.0000000523 mae     standard   7.76     10
 2     3          2  0.0000362   0.0000000523 rmse    standard   8.09     10
 3     3          2  0.0000362   0.0000000523 rsq     standard   0.328    10
 4    31          1  0.0000481   0.000000376  mae     standard   7.67     10
 5    31          1  0.0000481   0.000000376  rmse    standard   8.00     10
 6    31          1  0.0000481   0.000000376  rsq     standard   0.218    10
 7    39         12  0.0586      0.0316       mae     standard   1.57     10
 8    39         12  0.0586      0.0316       rmse    standard   1.99     10
 9    39         12  0.0586      0.0316       rsq     standard   0.311    10
10    12          1  0.0371      0.000000200  mae     standard   1.46     10
# ℹ 170 more rows
# ℹ 2 more variables: std_err <dbl>, .config <chr>

tidymodels - Best model

xgboost_tuned |>
  autoplot()

tidymodels - Best model

xgboost_best_params <- xgboost_tuned |>
  tune::select_best(metric = "rmse")

knitr::kable(xgboost_best_params)
min_n tree_depth learn_rate loss_reduction .config
17 3 0.0767369 6.994931 Preprocessor1_Model21

tidymodels - Final model

xgboost_model_final <- xgboost_model |>
  finalize_model(xgboost_best_params)
xgboost_model_final
Boosted Tree Model Specification (regression)

Main Arguments:
  trees = 1000
  min_n = 17
  tree_depth = 3
  learn_rate = 0.0767368959881715
  loss_reduction = 6.99493103991011

Engine-Specific Arguments:
  objective = reg:squarederror

Computational engine: xgboost 
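An alternative way to produce the train/test evaluation that follows: `tune::finalize_workflow()` plugs the best parameters into the workflow, and `tune::last_fit()` fits on the training portion of the split and scores the test portion in one call. A sketch assuming `xgboost_wf`, `xgboost_best_params`, and `nla_split` from the surrounding slides (`final_wf` and `final_fit` are hypothetical names); note this workflow uses a formula preprocessor, so the recipe's log/normalize steps are not applied here.

```r
# Sketch: finalize the tuned workflow and evaluate it on the held-out test
# set in one step. Assumes xgboost_wf, xgboost_best_params, and nla_split
# from the slides above.
final_wf <- tune::finalize_workflow(xgboost_wf, xgboost_best_params)

final_fit <- tune::last_fit(final_wf, split = nla_split)
tune::collect_metrics(final_fit)   # rmse / rsq on the test set
```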

tidymodels - Train evaluation

(train_processed <- bake(nla_recipe, new_data = nla_train))
# A tibble: 844 × 10
     temp       do     ph   cond     sd     tp     tn    chla clear_to_bottom
    <dbl>    <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>   <dbl> <lgl>          
 1 -0.501 NaN      -2.90  -0.278 -0.689 -0.317 -0.243 -0.581  TRUE           
 2  1.47   -0.157   0.304  0.191 -0.616 -0.286 -0.222  0.395  TRUE           
 3  1.44    0.569   0.360  0.232 -0.717 -0.255 -0.172  0.214  TRUE           
 4 -0.390 NaN      -2.58  -0.286 -0.460 -0.388 -0.450 -0.0359 TRUE           
 5  0.221 NaN      -3.20  -0.286 -0.643 -0.352 -0.393  0.358  TRUE           
 6 -1.53    0.279   1.49   3.68  -0.653  0.300  1.57   0.521  TRUE           
 7 -0.666  -0.108   1.49   3.70  -0.552  0.300  1.42  -0.0396 TRUE           
 8 -0.330  -1.49    0.338  1.18  -0.781  3.08   2.81  -0.549  TRUE           
 9 -0.748  -1.51    1.07   1.07  -0.689  5.63   2.59   1.03   TRUE           
10  1.29   -0.0116  0.134 -0.232 -0.506 -0.423 -0.410 -0.793  TRUE           
# ℹ 834 more rows
# ℹ 1 more variable: cyanophyta <dbl>

tidymodels - Train data

train_prediction <- xgboost_model_final |>
  # fit the model on all the training data
  fit(
    formula = nla_formula, 
    data    = train_processed
  ) |>
  # predict cyanophyta abundance for the training data
  predict(new_data = train_processed) |>
  bind_cols(nla_train |>
  mutate(.obs = log10(cyanophyta)))
xgboost_score_train <- 
  train_prediction |>
  yardstick::metrics(.obs, .pred) |>
  mutate(.estimate = format(round(.estimate, 2), big.mark = ","))
knitr::kable(xgboost_score_train)
.metric .estimator .estimate
rmse standard 0.81
rsq standard 0.38
mae standard 0.64

tidymodels - Train evaluation

train_prediction |>
  ggplot(aes(.pred, .obs)) +
  geom_point() +
  geom_smooth(method = "lm")

tidymodels - Test data

test_processed  <- bake(nla_recipe, new_data = nla_test)

test_prediction <- xgboost_model_final |>
  # fit the model on all the training data
  fit(
    formula = nla_formula,
    data    = train_processed
  ) |>
  # use the training model fit to predict the test data
  predict(new_data = test_processed) |>
  bind_cols(nla_test |>
  mutate(.obs = log10(cyanophyta)))

# measure the accuracy of our model using `yardstick`
xgboost_score <- test_prediction |>
  yardstick::metrics(.obs, .pred) |>
  mutate(.estimate = format(round(.estimate, 2), big.mark = ","))

knitr::kable(xgboost_score)
.metric .estimator .estimate
rmse standard 0.80
rsq standard 0.35
mae standard 0.62

tidymodels - Evaluation

cyanophyta_prediction_residual <- test_prediction |>
  arrange(.pred) |>
  mutate(residual_pct = (.obs - .pred) / .pred) |>
  select(.pred, residual_pct)

cyanophyta_prediction_residual |>
  ggplot(aes(x = .pred, y = residual_pct)) +
  geom_point() +
  xlab("Predicted Cyanophyta") +
  ylab("Residual (%)")

tidymodels - Test evaluation

test_prediction |>
  ggplot(aes(.pred, .obs)) +
  geom_point() +
  geom_smooth(method = "lm", colour = "black")

Discussion welcome!

苏命|https://drwater.rcees.ac.cn; https://drwater.rcees.ac.cn/bcard; Slides