Cross validation

STA 210 - Spring 2022

Dr. Mine Çetinkaya-Rundel

Welcome

Topics

  • Cross validation for model evaluation
  • Cross validation for model comparison

Computational setup

# load packages
library(tidyverse)
library(tidymodels)
library(knitr)
library(schrute)

Data & goal

  • Data: The data come from the schrute package and have been transformed following the instructions from Lab 4
  • Goal: Predict imdb_rating from other variables in the dataset
office_episodes <- read_csv(here::here("slides", "data/office_episodes.csv"))
office_episodes
# A tibble: 186 × 14
   season episode episode_name      imdb_rating total_votes air_date   lines_jim
    <dbl>   <dbl> <chr>                   <dbl>       <dbl> <date>         <dbl>
 1      1       1 Pilot                     7.6        3706 2005-03-24    0.157 
 2      1       2 Diversity Day             8.3        3566 2005-03-29    0.123 
 3      1       3 Health Care               7.9        2983 2005-04-05    0.172 
 4      1       4 The Alliance              8.1        2886 2005-04-12    0.202 
 5      1       5 Basketball                8.4        3179 2005-04-19    0.0913
 6      1       6 Hot Girl                  7.8        2852 2005-04-26    0.159 
 7      2       1 The Dundies               8.7        3213 2005-09-20    0.125 
 8      2       2 Sexual Harassment         8.2        2736 2005-09-27    0.0565
 9      2       3 Office Olympics           8.4        2742 2005-10-04    0.196 
10      2       4 The Fire                  8.4        2713 2005-10-11    0.160 
# … with 176 more rows, and 7 more variables: lines_pam <dbl>,
#   lines_michael <dbl>, lines_dwight <dbl>, halloween <chr>, valentine <chr>,
#   christmas <chr>, michael <chr>

Modeling prep

Split data into training and testing

set.seed(123)
office_split <- initial_split(office_episodes)
office_train <- training(office_split)
office_test <- testing(office_split)

Specify model

office_spec <- linear_reg() %>%
  set_engine("lm")

office_spec
Linear Regression Model Specification (regression)

Computational engine: lm 

Model 1

From yesterday’s lab

  • Create a recipe that uses the new variables we generated
  • Denote episode_name as an ID variable and don’t use air_date as a predictor
  • Create dummy variables for all nominal predictors
  • Remove all zero variance predictors

Create recipe

office_rec1 <- recipe(imdb_rating ~ ., data = office_train) %>%
  update_role(episode_name, new_role = "id") %>%
  step_rm(air_date) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors())

office_rec1
Recipe

Inputs:

      role #variables
        id          1
   outcome          1
 predictor         12

Operations:

Delete terms air_date
Dummy variables from all_nominal_predictors()
Zero variance filter on all_predictors()

Create workflow

office_wflow1 <- workflow() %>%
  add_model(office_spec) %>%
  add_recipe(office_rec1)

office_wflow1
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_rm()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Computational engine: lm 

Fit model to training data

Actually, not so fast!

Cross validation

Spending our data

  • We have already established the idea of data spending, where the test set is reserved for obtaining an unbiased estimate of performance.
  • However, we usually need to understand the effectiveness of the model before using the test set.
  • Typically we can’t decide on which final model to take to the test set without making model assessments.
  • Remedy: Resampling to make model assessments on training data in a way that can generalize to new data.

Resampling for model assessment

Resampling is only conducted on the training set. The test set is not involved. For each iteration of resampling, the data are partitioned into two subsamples:

  • The model is fit with the analysis set.
  • The model is evaluated with the assessment set.

Resampling for model assessment

[Figure: diagram of the resampling scheme. Source: Kuhn and Silge, Tidy Modeling with R.]

Analysis and assessment sets

  • Analysis set is analogous to training set.
  • Assessment set is analogous to test set.
  • The terms analysis and assessment avoid confusion with the initial split of the data.
  • These data sets are mutually exclusive.

Cross validation

More specifically, v-fold cross validation, a commonly used resampling technique:

  • Randomly split your training data into v partitions
  • Use 1 partition for assessment, and the remaining v-1 partitions for analysis
  • Repeat v times, updating which partition is used for assessment each time
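The three steps above can be sketched by hand in base R. This is a minimal illustration on synthetic data (all names here are made up for the example), not the tidymodels workflow used in these slides:

```r
# Hand-rolled v-fold CV on synthetic data (illustration only)
set.seed(123)
n <- 90
d <- data.frame(x = rnorm(n))
d$y <- 2 * d$x + rnorm(n)

v <- 3
# Step 1: randomly assign each row to one of v partitions
fold_id <- sample(rep(1:v, length.out = n))

rmse_by_fold <- sapply(1:v, function(k) {
  analysis_set   <- d[fold_id != k, ]  # v-1 partitions for fitting
  assessment_set <- d[fold_id == k, ]  # 1 partition for evaluation
  fit  <- lm(y ~ x, data = analysis_set)
  pred <- predict(fit, newdata = assessment_set)
  sqrt(mean((assessment_set$y - pred)^2))
})

# Steps 2 and 3 happen inside the loop; averaging the v replicates
# gives the overall CV estimate of RMSE
mean(rmse_by_fold)
```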

Let’s give an example where v = 3

Cross validation, step 1

Randomly split your training data into 3 partitions:


Split data

set.seed(345)
folds <- vfold_cv(office_train, v = 3)
folds
#  3-fold cross-validation 
# A tibble: 3 × 2
  splits          id   
  <list>          <chr>
1 <split [92/47]> Fold1
2 <split [93/46]> Fold2
3 <split [93/46]> Fold3
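To peek inside a single resample, rsample’s analysis() and assessment() helpers pull out the two subsamples of any split (shown here for the first fold, whose sizes match the [92/47] printed above):

```r
# extract the two subsamples of the first fold
fold1 <- folds$splits[[1]]
analysis(fold1)    # 92 rows used to fit the model
assessment(fold1)  # 47 rows held out for evaluation
```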

Cross validation, steps 2 and 3

  • Use 1 partition for assessment, and the remaining v-1 partitions for analysis
  • Repeat v times, updating which partition is used for assessment each time

Fit resamples

set.seed(456)

office_fit_rs1 <- office_wflow1 %>%
  fit_resamples(folds)

office_fit_rs1
# Resampling results
# 3-fold cross-validation 
# A tibble: 3 × 4
  splits          id    .metrics         .notes          
  <list>          <chr> <list>           <list>          
1 <split [92/47]> Fold1 <tibble [2 × 4]> <tibble [0 × 1]>
2 <split [93/46]> Fold2 <tibble [2 × 4]> <tibble [0 × 1]>
3 <split [93/46]> Fold3 <tibble [2 × 4]> <tibble [0 × 1]>

Cross validation, now what?

  • We’ve fit a bunch of models
  • Now it’s time to collect metrics (e.g., R-squared, RMSE) for each model and use them to evaluate model fit and how it varies across folds

Collect CV metrics

collect_metrics(office_fit_rs1)
# A tibble: 2 × 6
  .metric .estimator  mean     n std_err .config             
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1 rmse    standard   0.351     3  0.0111 Preprocessor1_Model1
2 rsq     standard   0.546     3  0.0378 Preprocessor1_Model1

Deeper look into CV metrics

cv_metrics1 <- collect_metrics(office_fit_rs1, summarize = FALSE) 

cv_metrics1
# A tibble: 6 × 5
  id    .metric .estimator .estimate .config             
  <chr> <chr>   <chr>          <dbl> <chr>               
1 Fold1 rmse    standard       0.356 Preprocessor1_Model1
2 Fold1 rsq     standard       0.520 Preprocessor1_Model1
3 Fold2 rmse    standard       0.367 Preprocessor1_Model1
4 Fold2 rsq     standard       0.498 Preprocessor1_Model1
5 Fold3 rmse    standard       0.330 Preprocessor1_Model1
6 Fold3 rsq     standard       0.621 Preprocessor1_Model1

Better tabulation of CV metrics

cv_metrics1 %>%
  mutate(.estimate = round(.estimate, 3)) %>%
  pivot_wider(id_cols = id, names_from = .metric, values_from = .estimate) %>%
  kable(col.names = c("Fold", "RMSE", "R-squared"))
Fold    RMSE   R-squared
Fold1   0.356  0.520
Fold2   0.367  0.498
Fold3   0.330  0.621

How does RMSE compare to y?

Cross validation RMSE stats:

cv_metrics1 %>%
  filter(.metric == "rmse") %>%
  summarise(
    min = min(.estimate),
    max = max(.estimate),
    mean = mean(.estimate),
    sd = sd(.estimate)
  )
# A tibble: 1 × 4
    min   max  mean     sd
  <dbl> <dbl> <dbl>  <dbl>
1 0.330 0.367 0.351 0.0192

IMDB rating stats for the full dataset:

office_episodes %>%
  summarise(
    min = min(imdb_rating),
    max = max(imdb_rating),
    mean = mean(imdb_rating),
    sd = sd(imdb_rating)
  )
# A tibble: 1 × 4
    min   max  mean    sd
  <dbl> <dbl> <dbl> <dbl>
1   6.7   9.7  8.25 0.535

Cross validation jargon

  • Referred to as v-fold or k-fold cross validation
  • Also commonly abbreviated as CV

Cross validation, for reals

  • To illustrate how CV works, we used v = 3:

    • Analysis sets are 2/3 of the training set
    • Each assessment set is a distinct 1/3
    • The final resampling estimate of performance averages each of the 3 replicates
  • This was useful for illustrative purposes, but v = 3 is a poor choice in practice

  • Values of v are most often 5 or 10; we generally prefer 10-fold cross-validation as a default
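Switching to 10-fold CV only requires changing the v argument; everything else in the pipeline stays the same. A sketch (folds10 and office_fit_rs10 are hypothetical names, not objects defined in these slides):

```r
# 10-fold CV with the same workflow as before
set.seed(345)
folds10 <- vfold_cv(office_train, v = 10)

office_fit_rs10 <- office_wflow1 %>%
  fit_resamples(folds10)

# metrics are now averaged over 10 replicates instead of 3
collect_metrics(office_fit_rs10)
```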

Application exercise

Recap

  • Cross validation for model evaluation
  • Cross validation for model comparison