class: center, middle, inverse, title-slide

# Cross-Validation
### K Arnold
### DATA 202 Fall 2020

---

.small-code[

]

## Q&A

> Was `model3` bad because it had a big difference between train and test?

Great question! The model *itself* was not bad, but we *over-sold* it.

> When is a decision tree better than linear regression? Worse?

* Decision tree: crisp regions are easy; smooth variation is hard.
* Linear regression: smooth variation is easy; crisp regions are hard.

> Could a linear regression use sinusoidal models?

Yes! There are lots of "basis functions" we can use (splines, sinusoids, rectifiers, ...)

---

## Q&A 2

> Can a tree check multiple conditions at once?

No, though it can check them in sequence. If you know that's important, you can create new features.

> Is there a limit to the number of decisions?

Yes: each decision must be *justified by the data*.

> Can greedy algorithms be worse?

Yes. Seeking short-term gain gives long-term regret.

---

## Quick datavis tips

* Using data from Pew? See [notes on *weights*](https://cs.calvin.edu/courses/data/202/fa20/spss-tips.html)
* Don't forget the cheat sheets! [`dplyr`](https://raw.githubusercontent.com/rstudio/cheatsheets/master/data-transformation.pdf), [`ggplot`](https://raw.githubusercontent.com/rstudio/cheatsheets/master/data-visualization-2.1.pdf), ...
  * I learned some things glancing over them again! (`coord_cartesian` instead of `xlim`, etc.)
* You can also use plotly, seaborn, etc., as long as it's reproducible.
* I'm not watching your every commit. If you need help, send a quick screenshot over Teams (preferably the Q&A channel).

---

class: center, middle

# Cross-Validation

---

## Why Cross-Validation?

Measure accuracy on unseen data *without peeking at the test set* (compare Lab 9)

<img src="w10d2-cross-validation_files/figure-html/compare-models-traintest-1.png" width="100%" style="display: block; margin: auto;" />

---

## What is Cross-Validation?
<img src="https://www.tmwr.org/premade/resampling.svg" width="90%" style="display: block; margin: auto;" />

---

<img src="https://www.tmwr.org/premade/three-CV-iter.svg" width="100%" style="display: block; margin: auto;" />

---

<img src="w10d2-cross-validation_files/figure-html/ames-cv-anim-.gif" width="100%" style="display: block; margin: auto;" />

---

## How to do CV?

1. Declare the splitting strategy:

```r
ames_resamples <- ames_train %>%
  vfold_cv(v = 10)
```

```r
ames_resamples
```

```
## #  10-fold cross-validation 
## # A tibble: 10 x 2
##    splits             id    
##    <list>             <chr> 
##  1 <split [1.4K/161]> Fold01
##  2 <split [1.4K/161]> Fold02
##  3 <split [1.4K/161]> Fold03
##  4 <split [1.4K/161]> Fold04
##  5 <split [1.4K/161]> Fold05
##  6 <split [1.4K/161]> Fold06
## # … with 4 more rows
```

---

## How to do CV?

1. Declare the splitting strategy
2. Fit on each resample, evaluate using a set of metrics.

<img src="w10d2-cross-validation_files/figure-html/ames-cv-model3-anim-.gif" width="100%" style="display: block; margin: auto;" />

---

## How to do CV?

1. Declare the splitting strategy
2. Fit on each resample, evaluate using a set of metrics.

```r
model3_samples <- model3_spec %>%
* fit_resamples(
    Sale_Price ~ Latitude + Longitude,
*   resamples = ames_resamples,
    metrics = metric_set(mae))
model3_samples %>%
  collect_metrics(summarize = FALSE)
```

```
## # A tibble: 10 x 4
##    id     .metric .estimator .estimate
##    <chr>  <chr>   <chr>          <dbl>
##  1 Fold01 mae     standard        34.2
##  2 Fold02 mae     standard        31.1
##  3 Fold03 mae     standard        29.0
##  4 Fold04 mae     standard        33.9
##  5 Fold05 mae     standard        29.1
##  6 Fold06 mae     standard        36.1
## # … with 4 more rows
```

---

## How to do CV?

1. Declare the splitting strategy
2. Fit on each resample, evaluate using a set of metrics.
3. Plot and/or summarize the metrics.
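Conceptually, the split declared in step 1 just assigns each training row to one of `v` folds; each fold then takes one turn as the assessment set while the rest is used for fitting. A minimal base-R sketch of that idea (illustrative only, not how `vfold_cv` is implemented; `n` and `v` below are made-up values):

```r
set.seed(10)
n <- 1609   # pretend training-set size
v <- 10     # number of folds

# Deal the shuffled row indices into v roughly equal folds
fold_id <- sample(rep(seq_len(v), length.out = n))

# For fold 1: held-out assessment rows vs. analysis (fitting) rows
assessment_rows <- which(fold_id == 1)
analysis_rows   <- which(fold_id != 1)

# Every row lands in exactly one fold, so across all v folds
# each observation is assessed exactly once.
length(assessment_rows) + length(analysis_rows) == n
```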
.pull-left[

```r
model3_samples %>%
  collect_metrics(summarize = FALSE) %>%
  ggplot(aes(x = .estimate, y = "model3")) +
  geom_point()
```

<img src="w10d2-cross-validation_files/figure-html/crude-plot-folds-1.png" width="100%" style="display: block; margin: auto;" />
]

.pull-right[

```r
model3_samples %>%
  collect_metrics(summarize = TRUE)
```

```
## # A tibble: 1 x 5
##   .metric .estimator  mean     n std_err
##   <chr>   <chr>      <dbl> <int>   <dbl>
## 1 mae     standard    32.7    10   0.751
```
]

---

## A tidy way to compare models

1. Make a data frame of model specs:

```r
all_models <- tribble(
  ~model_name, ~spec,
  "model1", decision_tree(mode = "regression", tree_depth = 2),
  "model2", decision_tree(mode = "regression", tree_depth = 30),
  "model3", decision_tree(mode = "regression", cost_complexity = 1e-6, min_n = 2)
)
```

---

## A tidy way to compare models

1. Make a data frame of model specs:
2. Sample each model (using `dplyr::rowwise`):

```r
models_with_samples <- all_models %>%
  rowwise() %>%
  mutate(samples = list(
    spec %>% fit_resamples(
      Sale_Price ~ Latitude + Longitude,
      resamples = ames_resamples,
      metrics = metric_set(mae))))
```

```r
models_with_samples
```

```
## # A tibble: 3 x 3
## # Rowwise: 
##   model_name spec      samples          
##   <chr>      <list>    <list>           
## 1 model1     <spec[+]> <tibble [10 × 4]>
## 2 model2     <spec[+]> <tibble [10 × 4]>
## 3 model3     <spec[+]> <tibble [10 × 4]>
```

---

```r
models_with_samples %>%
* rowwise(model_name) %>%
* summarize(collect_metrics(samples, summarize = FALSE)) %>%
  ggplot(aes(x = model_name, y = .estimate)) +
  geom_boxplot() +
  labs(x = "Model name", y = "Mean absolute error ($1000)") +
  coord_cartesian(ylim = c(0, NA))
```

<img src="w10d2-cross-validation_files/figure-html/unnamed-chunk-4-1.png" width="100%" style="display: block; margin: auto;" />

---

## Appendix: code

.small-code[

```r
add_predictions <- function(data, ...)
{
  imap_dfr(
    rlang::dots_list(..., .named = TRUE),
    function(model, model_name) {
      model %>%
        predict(data) %>%
        bind_cols(data) %>%
        mutate(model = !!model_name)
    }
  )
}

sweep_model <- function(model, var_to_sweep, sweep_min, sweep_max, ...) {
  X <- expand_grid(!!enquo(var_to_sweep) := seq(sweep_min, sweep_max, length.out = 500), ...)
  model %>%
    predict(X) %>%
    bind_cols(X)
}

linear_reg <- function(engine = "lm", ...) {
  parsnip::linear_reg(...) %>% set_engine(engine)
}

decision_tree <- function(mode = "regression", engine = "rpart", ...) {
  parsnip::decision_tree(mode = mode, ...) %>% set_engine(engine)
}

data(ames, package = "modeldata")
ames_all <- ames %>%
  filter(Gr_Liv_Area < 4000, Sale_Condition == "Normal") %>%
  mutate(across(where(is.integer), as.double)) %>%
  mutate(Sale_Price = Sale_Price / 1000)
rm(ames)

set.seed(10) # Seed the random number generator
ames_split <- initial_split(ames_all, prop = 2 / 3)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

lat_long_grid <- expand_grid(
  Latitude = modelr::seq_range(ames_train$Latitude, n = 200, expand = .05),
  Longitude = modelr::seq_range(ames_train$Longitude, n = 200, expand = .05),
)

show_latlong_model <- function(dataset, model, model_name = deparse(substitute(model))) {
  ggplot(dataset, aes(x = Longitude, y = Latitude)) +
    geom_raster(
      data = lat_long_grid %>% add_predictions(model),
      mapping = aes(fill = .pred)
    ) +
    geom_point(aes(color = Sale_Price), size = .5) +
    scale_color_viridis_c(aesthetics = c("color", "fill")) +
    coord_equal() +
    labs(title = model_name)
}

ames_resamples <- ames_train %>%
  vfold_cv(v = 10)

all_models <- tribble(
  ~model_name, ~spec,
  "model1", decision_tree(mode = "regression", tree_depth = 2),
  "model2", decision_tree(mode = "regression", tree_depth = 30),
  "model3", decision_tree(mode = "regression", cost_complexity = 1e-6, min_n = 2)
)

models_with_samples <- all_models %>%
  rowwise() %>%
  mutate(samples = list(
    spec %>% fit_resamples(
      Sale_Price ~ Latitude + Longitude,
      resamples = ames_resamples,
      metrics = metric_set(mae))))

model3_spec <- all_models$spec[[3]]

test_predictions <- all_models %>%
  rowwise(model_name) %>%
  # Fit on all training data
  mutate(fit_on_all_training_data = list(spec %>%
    fit(Sale_Price ~ Latitude + Longitude, data = ames_train))) %>%
  # Test on test set
  summarize(ames_test %>%
    add_predictions(fit_on_all_training_data) %>%
    mae(truth = Sale_Price, estimate = .pred))

models_with_samples %>%
  rowwise(model_name) %>%
  summarize(collect_metrics(samples, summarize = FALSE)) %>%
  bind_rows(
    train = .,
    test = test_predictions,
    .id = "assessment_data"
  ) %>%
  mutate(assessment_data = as_factor(assessment_data)) %>%
  ggplot(aes(x = model_name, y = .estimate, color = assessment_data)) +
  geom_boxplot() +
  labs(x = "Model name", y = "Mean absolute error ($1000)", fill = "Assessment dataset") +
  coord_cartesian(ylim = c(0, NA))

knitr::include_graphics("https://www.tmwr.org/premade/resampling.svg")
knitr::include_graphics("https://www.tmwr.org/premade/three-CV-iter.svg")

ames_train %>%
  vfold_cv(v = 10) %>%
  pull(splits) %>%
  iwalk(function(split, split_idx) {
    print(
      bind_rows(analysis = analysis(split), assessment = assessment(split), .id = "role") %>%
        ggplot(aes(x = Latitude, y = Longitude, color = role, shape = role)) +
        geom_point(size = 1) +
        scale_color_manual(values = c("analysis" = "grey", "assessment" = "red")) +
        labs(title = glue("Fold {split_idx}")) +
        theme_bw() +
        theme(panel.grid = element_blank())
    )
  })

withr::with_seed(0, {
  ames_train %>%
    vfold_cv(v = 10) %>%
    pull(splits) %>%
    iwalk(function(split, split_idx) {
      model <- model3_spec %>%
        fit(Sale_Price ~ Latitude + Longitude, data = analysis(split))
      assess_mae <- assessment(split) %>%
        add_predictions(model) %>%
        mae(truth = Sale_Price, estimate = .pred) %>%
        pull(.estimate)
      print(
        bind_rows(analysis = analysis(split), assessment = assessment(split), .id = "role") %>%
          ggplot(aes(x = Latitude, y = Longitude, color = role)) +
          geom_raster(
            data = lat_long_grid %>% add_predictions(model),
            mapping = aes(fill = .pred, color = NULL)
          ) +
          geom_point(size = .1) +
          scale_color_manual(values = c("analysis" = "grey", "assessment" = "red")) +
          labs(title = glue("Fold {split_idx} (MAE on assessment = {format(assess_mae, format='f', digits = 4)})")) +
          theme_bw() +
          theme(panel.grid = element_blank())
      )
    })
})
```
]
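For reference, the summary that `collect_metrics(summarize = TRUE)` reports is just the mean and standard error of the per-fold estimates. A quick sketch with hypothetical fold MAEs (the first six values come from the `model3` output earlier; the last four are made up):

```r
# Per-fold MAE estimates: six from the earlier output, four hypothetical
fold_mae <- c(34.2, 31.1, 29.0, 33.9, 29.1, 36.1, 33.0, 31.5, 34.5, 34.8)

cv_mean <- mean(fold_mae)                        # cross-validated MAE, ~32.7
cv_se   <- sd(fold_mae) / sqrt(length(fold_mae)) # standard error across folds
```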