2.2. Model sculpting and interpretation - bike

Published

2024-09-05

Note

The following script must be run before knitting this document, using the model_type specified below:

  • 3_model_sculpt_bike.R

Setup and load

Show the code
library(dplyr)
library(stringr)
source(here::here("R/0_setup.R"))

# Default look for every plot in this report: black-and-white theme, 12pt base.
theme_set(theme_bw(base_size = 12))

# Add-on for single-panel figures: only enlarges the text.
theme_single <- theme(text = element_text(size = 16))

# Add-on for faceted figures: enlarged text plus an untitled, boxed legend
# drawn inside the plotting area (towards the lower right).
theme_facets <- theme(
  text = element_text(size = 16),
  legend.title = element_blank(),
  legend.background = element_rect(colour = "black"),
  legend.position = "inside",
  legend.position.inside = c(0.85, 0.2)
)

# Load a fitted model for the active dataset, or fail with instructions.
#
# The stored file name is `sname(paste0(model_type, "-fit_final.rds"))`, i.e.
# it is prefixed with the global `dataset`, and is looked up in the global
# `storage_folder`.
#
# @param model_type Model identifier used in the stored file name
#   (e.g. "xgb_bayes").
# @return The result of `load_results()` for the stored fit.
# Stops with an error naming the missing file and the training command when
# the file does not exist.
load_model_if_trained <- function(model_type) {
  model_file <- sname(paste0(model_type, "-fit_final.rds"))
  model_path <- file.path(storage_folder, model_file)

  if (!file.exists(model_path)) {
    # Report the dataset actually in use (the same one `sname()` puts into the
    # file name) instead of a hard-coded dataset name.
    stop(
      "Model ", model_type, " is not trained for ", dataset, ", `",
      model_path, "` does not exist. ",
      "Please train it first with: ",
      "`Rscript R/2_train_models.R ", model_type, " ", dataset, " FALSE`",
      call. = FALSE
    )
  }

  load_results(model_file)
}
Show the code
# Dataset analysed in this report.
# NOTE(review): an earlier comment here read "any with discrete response", yet
# the code below works with numeric `.pred` predictions - confirm which
# response types are actually supported.
dataset <- "bike"

# Strong learner to interpret; the commented line is the alternative choice.
model_type <- "xgb_bayes"
# model_type <- "xgb"

# Number of features for a polished model.
top_k <- 7

# Build a storage name of the form "<prefix>-<x>".
#
# @param x File name (or other suffix) to be prefixed; may be a vector.
# @param prefix Prefix to prepend; defaults to the global `dataset`.
# @return Character vector with `prefix` and `x` joined by "-".
sname <- function(x, prefix = dataset) {
  paste(prefix, x, sep = "-")
}


# Load the dataset definition. `dd` is used throughout this document for the
# data splits (`dd$data`), covariate names (`dd$covariates`) and the response
# name (`dd$response`).
dd <- define_data(dataset)

# Load the fitted strong learner selected by `model_type`.
xgb <- load_model_if_trained(model_type)
# Load the first-order model ("xgb_1_order"); the "_bayes" variant is kept
# commented out as an alternative.
# xgb_fo <- load_model_if_trained("xgb_1_order_bayes")
xgb_fo <- load_model_if_trained("xgb_1_order")

# Get product marginals: 1e4 samples drawn (fixed seed) from the training
# covariates.
pm <- sample_marginals(dd$data$train[dd$covariates$all], n = 1e4, seed = 1234)

xgb model

Sculpting

Main models

Show the code
# Load the pre-computed sculptures (rough, detailed, polished) built on
# product marginals for the selected model. File names follow the pattern
# "<dataset>-<model_type>-sculpt_<kind>_pm.rds".
rs_pm <- load_results(sname(paste0(model_type, "-sculpt_rough_pm.rds")))
ds_pm <- load_results(sname(paste0(model_type, "-sculpt_detailed_pm.rds")))
ps_pm <- load_results(sname(paste0(model_type, "-sculpt_polished_pm.rds")))

Other sculpting models

Show the code
# Rough model sculpted on the training covariates themselves instead of the
# sampled product marginals.
# NOTE(review): `data_as_marginals = TRUE` presumably tells `sculpt_rough()` to
# use the supplied data as-is - confirm against the modsculpt documentation.
rs_train <- sculpt_rough(
  dd$data$train[dd$covariates$all],
  seed = 1234,
  model_predict_fun = function(x) {
    predict(xgb, new_data = x)$.pred
  },
  data_as_marginals = TRUE
)

# Rough model sculpted from the first-order XGBoost (`xgb_fo`) on the product
# marginals `pm`; used later as the "Direct Additive XGB" comparator.
rs_pm_xgb_fo <- sculpt_rough(
  pm, 
  seed = 1234,
  model_predict_fun = function(x) {
    predict(xgb_fo, new_data = x)$.pred
  }
)

ICE plots

Show the code
# Colour scale override for the centered ICE plots: grey ICE profiles, a blue
# rough-model curve, and a shorter legend label for the rough model.
scale_col_update <- scale_color_manual(
  values = c("ICE Profiles" = "gray60", "Rough model (with SE)" = "blue"),
  labels = c("ICE Profiles", "Rough model"),
  name = ""
)

# Uncentered ICE profiles without the PDP overlay.
# `TRUE`/`FALSE` are spelled out because `T`/`F` are ordinary variables that
# can be reassigned.
ice_pm_ceteris <- g_ice(
  rs_pm,
  centered = FALSE,
  show_PDP = FALSE,
  facet_spec = facet_specification(ncol = 3)
)

# Centered ICE profiles with the PDP overlay.
ice_pm <- g_ice(
  rs_pm,
  centered = TRUE,
  show_PDP = TRUE,
  facet_spec = facet_specification(ncol = 3)
)
Show the code
# Uncentered ICE profiles, `continuous` panel, with the facet theme applied.
ice_pm_ceteris$continuous + theme_facets

Show the code
# Uncentered ICE profiles, `discrete` panel (shown with the default theme).
ice_pm_ceteris$discrete

Show the code
# Centered ICE profiles, `continuous` panel, with the custom colour scale and
# facet theme.
ice_pm$continuous + scale_col_update + theme_facets

Show the code
# Centered ICE profiles, `discrete` panel, with the custom colour scale and
# facet theme.
ice_pm$discrete + scale_col_update + theme_facets

Show the code
# Comparison plot: rough model sculpted from the strong learner vs. the rough
# model sculpted from the first-order ("direct additive") XGBoost.
comp_xgb_bayes <- g_comparison(
  sculptures = list(rs_pm, rs_pm_xgb_fo),
  descriptions = c("Rough Model", "Direct Additive XGB"), 
  facet_spec = facet_specification(ncol = 3)
)

# Show the continuous panel with the facet theme; the discrete panel keeps the
# default theme.
comp_xgb_bayes$continuous + theme_facets
comp_xgb_bayes$discrete

Show the code
# Compare the detailed and rough models (both sculpted on product marginals).
comp_ds <- 
  g_comparison(
    sculptures = list(rs_pm, ds_pm),
    descriptions = c("Rough Model", "Detailed Model"), 
    facet_spec = facet_specification(ncol = 3)
  )

# Show both panels with the facet theme.
comp_ds$continuous + theme_facets
comp_ds$discrete + theme_facets

Show the code
# Compare the polished and rough models (both sculpted on product marginals).
comp_ps <- 
  g_comparison(
    sculptures = list(rs_pm, ps_pm),
    descriptions = c("Rough Model", "Polished Model"), 
    facet_spec = facet_specification(ncol = 3)
  )

# Show both panels with the facet theme.
comp_ps$continuous + theme_facets
comp_ps$discrete + theme_facets

Data density

Show the code
# Build one density/ICE plot per variable for the polished model, using the
# training data.
g_density_plots <- g_density_ice_plot_list(ps_pm,
                                           dd$data$train,
                                           var_names = dd$covariates$all,
                                           var_labels = dd$covariates$labels,
                                           task = dd$task)
# Keep only covariates that are present in the polished model, split into
# continuous and discrete ones.
cov_cont_ps <- intersect(dd$covariates$continuous, names(ps_pm))
cov_disc_ps <- intersect(dd$covariates$discrete, names(ps_pm))

# Arrange the selected plots in two-column grids.
patchwork::wrap_plots(g_density_plots[cov_cont_ps], ncol = 2) 
patchwork::wrap_plots(g_density_plots[cov_disc_ps], ncol = 2) 

Additivity evaluation

Show the code
# Strong learner predictions on the product marginals (p1) and train set (p2).
p1 <- predict(xgb, new_data = pm)$.pred
p2 <- predict(xgb, new_data = dd$data$train)$.pred
# Rough model predictions: the pm-based sculpture on the product marginals
# (p3) and the train-based sculpture on the train set (p4).
p3 <- predict(rs_pm, pm)
p4 <- predict(rs_train, dd$data$train)

# Additivity plot: rough model predictions (`sp`) against strong learner
# predictions (`lp`), one entry per data source in `descriptions`.
g_additivity(
  sp = list(p3, p4),
  lp = list(p1, p2),
  descriptions = c("Product Marginals", "Train Set")
) + 
    labs(x = "Rough Model Predictions", y = "Strong Learner Predictions") + 
    theme_single

Variable importance

Show the code
# Variable importance of the rough model sculpted on product marginals
# (ICE-based importance type, PDP panel switched off).
vi_pm <- g_var_imp(rs_pm, show_pdp_plot = FALSE, textsize = 16, var_imp_type = "ice")
plot(vi_pm)

Show the code
# Variable importance of the rough model sculpted on the training data
# (ICE-based importance type, PDP panel switched off).
vi_train <- g_var_imp(rs_train, show_pdp_plot = FALSE, textsize = 16, var_imp_type = "ice")
plot(vi_train)

Calibration

Show the code
# Holdout predictions of the strong learner, the rough and polished
# sculptures and the first-order XGBoost, reshaped to long format (one row per
# observation and model) with display labels in a fixed factor-level order.
# Columns in `tibble()` are evaluated in order, so `xgb` on the right-hand
# side of the `xgb` column still refers to the model object.
preds_sculptures <- tibble(
  obs = dd$data$holdout[[dd$response]],
  xgb = predict(xgb, new_data = dd$data$holdout)$.pred,
  rm = predict(rs_pm, newdata = dd$data$holdout),
  pm = predict(ps_pm, newdata = dd$data$holdout),
  dir = predict(xgb_fo, new_data = dd$data$holdout)$.pred
) %>%
  pivot_longer(cols = -obs, names_to = "Model", values_to = "pred") %>%
  mutate(
    # Translate column keys into display labels and fix the legend order.
    Model = factor(
      c(
        "xgb" = "XGBoost",
        "rm" = "Rough Model",
        "pm" = "Polished Model",
        "dir" = "Direct Additive XGBoost"
      )[Model],
      levels = c(
        "XGBoost", "Rough Model", "Polished Model", "Direct Additive XGBoost"
      )
    )
  )

# Calibration plot on the holdout set: a GAM-smoothed curve of observed vs.
# predicted values per model, with the identity line as dashed reference.
# `se = FALSE` is spelled out because `F` is an ordinary, reassignable
# variable.
calib_plot_sculptures <- ggplot(preds_sculptures) +
  geom_smooth(
    aes(x = pred, y = obs, colour = Model),
    se = FALSE,
    method = "gam",
    formula = y ~ x
  ) +
  geom_abline(linetype = "dashed") +
  labs(x = "Prediction", y = "Truth") +
  theme_bw() +
  theme(text = element_text(size = 18))

calib_plot_sculptures

Compare with linear models

Load and sculpt

Show the code
# Load the previously trained penalised linear models.
elastic <- load_model_if_trained("lm_elastic")
lasso <- load_model_if_trained("lm_lasso")
ridge <- load_model_if_trained("lm_ridge")

# Define a plain linear model (`type = "lm"`) and fit its workflow on the
# training data.
dm_linm <- define_model(type = "lm", data_info = dd)
linm <- fit(dm_linm$workflow, data = dd$data$train)
# Resampling fit of the same workflow on `dd$cv`.
# NOTE(review): `tg_linm` is not referenced again in the visible part of this
# document - confirm whether it is still needed.
tg_linm <- fit_resamples(dm_linm$workflow, dd$cv)
Show the code
# Sculpt a rough model on the product marginals `pm` for a fitted model whose
# `predict()` result has a `.pred` column (same seed for every model).
sculpt_rough_on_pm <- function(model) {
  force(model)
  sculpt_rough(
    pm,
    seed = 1234,
    model_predict_fun = function(x) {
      predict(model, new_data = x)$.pred
    }
  )
}

# sculptures on pm from different models
rs_pm_elastic <- sculpt_rough_on_pm(elastic)
rs_pm_lasso <- sculpt_rough_on_pm(lasso)
rs_pm_ridge <- sculpt_rough_on_pm(ridge)
rs_pm_linm <- sculpt_rough_on_pm(linm)

ICE

Show the code
# Compare the rough sculptures of the linear models with the polished model.
# `rs_pm_linm` comes from the plain linear model (`type = "lm"`, numeric
# `.pred` predictions), so it is labelled "Linear Regression".
comp_models <- g_comparison(
  sculptures = list(rs_pm_elastic, rs_pm_lasso, rs_pm_ridge, rs_pm_linm, ps_pm),
  descriptions = c("Elastic Net", "Lasso", "Ridge", "Linear Regression", "Polished"),
  facet_spec = facet_specification(ncol = 3)
)

# Show both panels with the facet theme.
comp_models$continuous + theme_facets
comp_models$discrete + theme_facets

Calibration

Show the code
# Holdout predictions of the polished model and of the rough sculptures of
# each linear model, reshaped to long format (one row per observation and
# model) with display labels.
preds_models <- tibble(
  obs = dd$data$holdout[[dd$response]],
  xgbPol = predict(ps_pm, newdata = dd$data$holdout),
  linm = predict(rs_pm_linm, newdata = dd$data$holdout),
  elastic = predict(rs_pm_elastic, newdata = dd$data$holdout),
  lasso = predict(rs_pm_lasso, newdata = dd$data$holdout),
  ridge = predict(rs_pm_ridge, newdata = dd$data$holdout)
) %>%
  pivot_longer(
    cols = -obs,
    names_to = "Model",
    values_to = "pred"
  ) %>%
  mutate(
    # `linm` is the plain linear model (`type = "lm"`), hence "Linear".
    Model = c(
      "xgbPol" = "Polished",
      "linm" = "Linear",
      "elastic" = "Elastic Net",
      "lasso" = "Lasso",
      "ridge" = "Ridge"
    )[Model]
  )

# Calibration plot on holdout based on pm sculptures of different linear
# models: a GAM-smoothed curve of observed vs. predicted values per model,
# with the identity line as dashed reference.
# `se = FALSE` is spelled out because `F` is an ordinary, reassignable
# variable.
calib_plot_models <- ggplot(preds_models) +
  geom_smooth(
    aes(x = pred, y = obs, colour = Model),
    se = FALSE,
    method = "gam",
    formula = y ~ x
  ) +
  geom_abline(linetype = "dashed") +
  labs(x = "Prediction", y = "Truth") +
  theme_bw() +
  theme(text = element_text(size = 18))

calib_plot_models

Session info

Show the code
# Record the R version, platform and package versions used to render this
# document.
devtools::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.3.3 (2024-02-29)
 os       Ubuntu 22.04.4 LTS
 system   x86_64, linux-gnu
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       Etc/UTC
 date     2024-09-05
 pandoc   3.1.13 @ /opt/conda/bin/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version    date (UTC) lib source
 backports      1.4.1      2021-12-13 [2] RSPM (R 4.3.0)
 broom        * 1.0.5      2023-06-09 [2] RSPM (R 4.3.0)
 cachem         1.0.8      2023-05-01 [2] RSPM (R 4.3.0)
 checkmate      2.3.1      2023-12-04 [2] RSPM (R 4.3.0)
 class          7.3-22     2023-05-03 [2] RSPM (R 4.3.0)
 cli            3.6.2      2023-12-11 [2] RSPM (R 4.3.0)
 codetools      0.2-20     2024-03-31 [4] RSPM (R 4.3.3)
 colorspace     2.1-0      2023-01-23 [2] RSPM (R 4.3.0)
 data.table     1.15.4     2024-03-30 [2] RSPM (R 4.3.0)
 devtools       2.4.5      2022-10-11 [1] RSPM (R 4.3.0)
 dials        * 1.2.1      2024-02-22 [1] RSPM (R 4.3.0)
 DiceDesign     1.10       2023-12-07 [1] RSPM (R 4.3.0)
 digest         0.6.35     2024-03-11 [2] RSPM (R 4.3.0)
 dplyr        * 1.1.4      2023-11-17 [2] RSPM (R 4.3.0)
 ellipsis       0.3.2      2021-04-29 [2] RSPM (R 4.3.0)
 evaluate       0.23       2023-11-01 [2] RSPM (R 4.3.0)
 fansi          1.0.6      2023-12-08 [2] RSPM (R 4.3.0)
 farver         2.1.1      2022-07-06 [2] RSPM (R 4.3.0)
 fastmap        1.1.1      2023-02-24 [2] RSPM (R 4.3.0)
 foreach        1.5.2      2022-02-02 [1] RSPM (R 4.3.0)
 fs             1.6.3      2023-07-20 [2] RSPM (R 4.3.0)
 furrr          0.3.1      2022-08-15 [1] RSPM (R 4.3.0)
 future         1.33.2     2024-03-26 [1] RSPM (R 4.3.0)
 future.apply   1.11.2     2024-03-28 [1] RSPM (R 4.3.0)
 generics       0.1.3      2022-07-05 [2] RSPM (R 4.3.0)
 ggplot2      * 3.5.0      2024-02-23 [2] RSPM (R 4.3.0)
 ggrepel        0.9.5      2024-01-10 [1] RSPM (R 4.3.0)
 glmnet         4.1-8      2023-08-22 [1] RSPM (R 4.3.0)
 globals        0.16.3     2024-03-08 [1] RSPM (R 4.3.0)
 glue           1.7.0      2024-01-09 [2] RSPM (R 4.3.0)
 gower          1.0.1      2022-12-22 [1] RSPM (R 4.3.0)
 GPfit          1.0-8      2019-02-08 [1] RSPM (R 4.3.0)
 gridExtra      2.3        2017-09-09 [2] RSPM (R 4.3.0)
 gtable         0.3.4      2023-08-21 [2] RSPM (R 4.3.0)
 hardhat        1.3.1      2024-02-02 [1] RSPM (R 4.3.0)
 here           1.0.1      2020-12-13 [1] RSPM (R 4.3.0)
 htmltools      0.5.8.1    2024-04-04 [2] RSPM (R 4.3.0)
 htmlwidgets    1.6.4      2023-12-06 [2] RSPM (R 4.3.0)
 httpuv         1.6.15     2024-03-26 [2] RSPM (R 4.3.0)
 infer        * 1.0.7      2024-03-25 [1] RSPM (R 4.3.0)
 ipred          0.9-14     2023-03-09 [1] RSPM (R 4.3.0)
 iterators      1.0.14     2022-02-05 [1] RSPM (R 4.3.0)
 jsonlite       1.8.8      2023-12-04 [2] RSPM (R 4.3.0)
 knitr          1.46       2024-04-06 [2] RSPM (R 4.3.0)
 labeling       0.4.3      2023-08-29 [2] RSPM (R 4.3.0)
 later          1.3.2      2023-12-06 [2] RSPM (R 4.3.0)
 lattice        0.22-6     2024-03-20 [4] RSPM (R 4.3.3)
 lava           1.8.0      2024-03-05 [1] RSPM (R 4.3.0)
 lhs            1.1.6      2022-12-17 [1] RSPM (R 4.3.0)
 lifecycle      1.0.4      2023-11-07 [2] RSPM (R 4.3.0)
 listenv        0.9.1      2024-01-29 [1] RSPM (R 4.3.0)
 lubridate      1.9.3      2023-09-27 [2] RSPM (R 4.3.0)
 magrittr       2.0.3      2022-03-30 [2] RSPM (R 4.3.0)
 MASS           7.3-60.0.1 2024-01-13 [4] RSPM (R 4.3.3)
 Matrix         1.6-5      2024-01-11 [4] RSPM (R 4.3.3)
 memoise        2.0.1      2021-11-26 [2] RSPM (R 4.3.0)
 mgcv         * 1.9-1      2023-12-21 [4] RSPM (R 4.3.3)
 mime           0.12       2021-09-28 [2] RSPM (R 4.3.0)
 miniUI         0.1.1.1    2018-05-18 [2] RSPM (R 4.3.0)
 modeldata    * 1.3.0      2024-01-21 [1] RSPM (R 4.3.0)
 modsculpt    * 0.1.1      2024-09-05 [1] Github (Genentech/modsculpt@65bdd78)
 munsell        0.5.1      2024-04-01 [2] RSPM (R 4.3.0)
 nlme         * 3.1-164    2023-11-27 [4] RSPM (R 4.3.3)
 nnet           7.3-19     2023-05-03 [4] RSPM (R 4.3.3)
 parallelly     1.37.1     2024-02-29 [1] RSPM (R 4.3.0)
 parsnip      * 1.2.1      2024-03-22 [1] RSPM (R 4.3.0)
 patchwork      1.2.0      2024-01-08 [1] RSPM (R 4.3.0)
 pillar         1.9.0      2023-03-22 [2] RSPM (R 4.3.0)
 pkgbuild       1.4.4      2024-03-17 [2] RSPM (R 4.3.0)
 pkgconfig      2.0.3      2019-09-22 [2] RSPM (R 4.3.0)
 pkgload        1.3.4      2024-01-16 [2] RSPM (R 4.3.0)
 prodlim        2023.08.28 2023-08-28 [1] RSPM (R 4.3.0)
 profvis        0.3.8      2023-05-02 [1] RSPM (R 4.3.0)
 promises       1.3.0      2024-04-05 [2] RSPM (R 4.3.0)
 purrr        * 1.0.2      2023-08-10 [2] RSPM (R 4.3.0)
 R6             2.5.1      2021-08-19 [2] RSPM (R 4.3.0)
 Rcpp           1.0.12     2024-01-09 [2] RSPM (R 4.3.0)
 recipes      * 1.0.10     2024-02-18 [1] RSPM (R 4.3.0)
 remotes        2.5.0      2024-03-17 [2] RSPM (R 4.3.0)
 rlang          1.1.3      2024-01-10 [2] RSPM (R 4.3.0)
 rmarkdown      2.26       2024-03-05 [2] RSPM (R 4.3.0)
 rpart          4.1.23     2023-12-05 [4] RSPM (R 4.3.3)
 rprojroot      2.0.4      2023-11-05 [2] RSPM (R 4.3.0)
 rsample      * 1.2.1      2024-03-25 [1] RSPM (R 4.3.0)
 rstudioapi     0.16.0     2024-03-24 [2] RSPM (R 4.3.0)
 scales       * 1.3.0      2023-11-28 [2] RSPM (R 4.3.0)
 sessioninfo    1.2.2      2021-12-06 [1] RSPM (R 4.3.0)
 shape          1.4.6.1    2024-02-23 [1] RSPM (R 4.3.0)
 shiny          1.8.1.1    2024-04-02 [2] RSPM (R 4.3.0)
 stats4phc    * 0.1.1      2024-06-20 [1] Github (genentech/stats4phc@e868e23)
 stringi        1.8.3      2023-12-11 [2] RSPM (R 4.3.0)
 stringr      * 1.5.1      2023-11-14 [2] RSPM (R 4.3.0)
 survival       3.5-8      2024-02-14 [4] RSPM (R 4.3.3)
 tibble       * 3.2.1      2023-03-20 [2] RSPM (R 4.3.0)
 tidymodels   * 1.2.0      2024-03-25 [1] RSPM (R 4.3.0)
 tidyr        * 1.3.1      2024-01-24 [2] RSPM (R 4.3.0)
 tidyselect     1.2.1      2024-03-11 [2] RSPM (R 4.3.0)
 timechange     0.3.0      2024-01-18 [2] RSPM (R 4.3.0)
 timeDate       4032.109   2023-12-14 [1] RSPM (R 4.3.0)
 tune         * 1.2.0      2024-03-20 [1] RSPM (R 4.3.0)
 urlchecker     1.0.1      2021-11-30 [1] RSPM (R 4.3.0)
 usethis        2.2.2      2023-07-06 [1] RSPM (R 4.3.0)
 utf8           1.2.4      2023-10-22 [2] RSPM (R 4.3.0)
 vctrs          0.6.5      2023-12-01 [2] RSPM (R 4.3.0)
 viridisLite    0.4.2      2023-05-02 [2] RSPM (R 4.3.0)
 withr          3.0.0      2024-01-16 [2] RSPM (R 4.3.0)
 workflows    * 1.1.4      2024-02-19 [1] RSPM (R 4.3.0)
 workflowsets * 1.1.0      2024-03-21 [1] RSPM (R 4.3.0)
 xfun           0.43       2024-03-25 [4] RSPM (R 4.3.3)
 xgboost        1.7.7.1    2024-01-25 [1] RSPM (R 4.3.0)
 xtable         1.8-4      2019-04-21 [2] RSPM (R 4.3.0)
 yaml           2.3.8      2023-12-11 [2] RSPM (R 4.3.0)
 yardstick    * 1.3.1      2024-03-21 [1] RSPM (R 4.3.0)

 [1] /home/yoshidk6/R/x86_64-pc-linux-gnu-library/4.3
 [2] /usr/local/lib/R/site-library
 [3] /usr/lib/R/site-library
 [4] /usr/lib/R/library

──────────────────────────────────────────────────────────────────────────────