| 1 |
# product marginals -------- |
|
| 2 | ||
| 3 |
#' Sample product marginals dataset |
|
| 4 |
#' |
|
| 5 |
#' @param dat Data.frame to sample from, must include only covariates. |
|
| 6 |
#' @param n Number of observations to sample. |
|
| 7 |
#' @param seed `NULL` or seed for exact reproducibility. |
|
| 8 |
#' |
|
| 9 |
#' @details The product marginals dataset is a grid of values that is sampled independently |
|
| 10 |
#' for each column (feature) from the original dataset. |
|
| 11 |
#' The aim here is to disentangle the correlations between features and assess |
|
| 12 |
#' how each feature affects the model predictions individually. |
|
| 13 |
#' It will not contain new values per column, but it may contain new combinations of values not |
|
| 14 |
#' seen in the original data. |
|
| 15 |
#' One can also check how the model behaves if there are unseen observations |
|
| 16 |
#' (new combination of features). |
|
| 17 |
#'
|
| 18 |
#' Note that the use of the product marginal dataset for model sculpting only works |
|
| 19 |
#' if the features are approximately additive for model predictions. |
|
| 20 |
#' In the quite rare case when they are not, the sculpted models using the product marginal |
|
| 21 |
#' dataset are expected to have significantly lower performance and |
|
| 22 |
#' the conclusions may be misleading. |
|
| 23 |
#' |
|
| 24 |
#' One can also try using the original data instead of the product marginals for model |
|
| 25 |
#' sculpting and see how the results differ. |
|
| 26 |
#' |
|
| 27 |
#' @return `data.frame` with same number of columns and `n` rows. |
|
| 28 |
#' @export |
|
| 29 |
#' |
|
| 30 |
#' @examples |
|
| 31 |
#' sample_marginals(mtcars, n = 5, seed = 543) |
|
| 32 |
sample_marginals <- function(dat, n, seed = NULL) {
  # Draws `n` row indexes per column, independently for every column, and
  # binds the selected values back together column-wise.
  checkmate::assert_data_frame(dat, any.missing = FALSE)
  checkmate::assert_integerish(n, lower = 1, any.missing = FALSE, len = 1)

  dat <- as.data.frame(dat)
  cols <- colnames(dat)
  stopifnot(ncol(dat) > 0, nrow(dat) > 0)

  # Seed only on request: `set.seed(NULL)` would re-initialise the RNG and
  # thereby discard a seed the caller has set before calling this function.
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # indexes: random samples of length n, individual per column
  idx_per_cols <- lapply(
    seq_along(cols),
    function(...) sample.int(nrow(dat), size = n, replace = TRUE)
  )

  # get values for the indexes above (one single-column data.frame per column)
  dat_sub <- lapply(
    seq_along(cols),
    function(i) dat[idx_per_cols[[i]], cols[i], drop = FALSE]
  )

  # bind the sampled columns; `row.names = NULL` is passed on to cbind()
  do.call("cbind", c(dat_sub, list(row.names = NULL)))
}
|
| 57 | ||
| 58 | ||
| 59 | ||
| 60 |
# ICE data -------- |
|
| 61 | ||
| 62 |
# calculate ICE data by using product marginals and prediction function |
|
| 63 |
calculate_ice_data <- function(sub, predict_fun, x, x_name, col_order) {
  # Builds ICE data for one feature.
  #   sub         data.frame with values of the other features (one ICE line
  #               per row), or NULL when `x` is the only feature
  #   predict_fun function returning predictions for a data.frame
  #   x           atomic vector with the values of the feature
  #   x_name      column name of the feature
  #   col_order   columns (and their order) selected before `predict_fun` is
  #               called
  # Returns a data.table with columns x, ice, ice_centered and line_id.
  stopifnot(
    is.data.frame(sub) || is.null(sub),
    is.function(predict_fun),
    is.atomic(x),
    is.character(x_name),
    is.character(col_order)
  )

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- ice <- line_id <- ..x <- ice_centered <- NULL # due to NSE notes in R CMD check

  # special case: sculpting performed on 1 variable -> a single ICE line
  if (is.null(sub)) {
    preds <- predict_fun(structure(data.frame(x), names = x_name))
    return(data.table(
      x = x,
      ice = preds,
      ice_centered = preds - mean(preds),
      line_id = 1
    ))
  }

  # all other cases: `sub` must not already hold the feature and needs at
  # least one row (a zero-row `sub` cannot produce any ICE line)
  stopifnot(!x_name %in% colnames(sub), nrow(sub) > 0)
  n_lines <- nrow(sub)

  # combine all values of x with every row of `sub`, one chunk per ICE line
  out <- rbindlist(
    lapply(seq_len(n_lines), function(i) cbind(x, sub[i, , drop = FALSE], row.names = NULL))
  )
  setnames(out, "x", x_name)
  # drop = FALSE keeps a data.frame even if `col_order` names a single column
  out[, ice := predict_fun(as.data.frame(out)[, col_order, drop = FALSE])]
  out[, line_id := rep(seq_len(n_lines), each = length(..x))]
  # center every ICE line around its own mean
  out[, ice_centered := ice - mean(ice), line_id]
  setnames(out, x_name, "x")
  out[, c("x", "ice", "ice_centered", "line_id")]
}
|
| 100 | ||
| 101 |
# generate ICE data from stored ICE predictions |
|
| 102 |
# the result is similar shape as the returned object from calculate_ice_data |
|
| 103 |
generate_ice_data <- function(predictions, x, logodds_to_prob = FALSE) {
  # Stacks stored ICE predictions (a list with one element per ICE line) into
  # one data.table with columns x, y and line_id.
  stopifnot(is.list(predictions), is.atomic(x))

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  ..x <- line_id <- NULL # due to NSE notes in R CMD check

  # one (x, y) data.frame per ICE line; y is optionally passed through
  # inv.logit() first
  ice_lines <- lapply(predictions, function(line_pred) {
    y_vals <- if (logodds_to_prob) inv.logit(line_pred) else line_pred
    data.frame(x = x, y = y_vals, row.names = NULL)
  })
  stacked <- rbindlist(ice_lines)

  # number the lines in the order of `predictions`
  stacked[, line_id := rep(seq_along(predictions), each = length(..x))]

  return(stacked)
}
|
| 128 | ||
| 129 | ||
| 130 |
# PDP data ----------- |
|
| 131 | ||
| 132 |
# calculate PDP data from ICE data |
|
| 133 |
calculate_pdp_data <- function(id) {
  # Aggregates ICE data (a data.table with columns `x` and `ice_centered`)
  # into PDP data: per distinct `x` the mean of `ice_centered` and its
  # standard error (sd / sqrt(number of rows)).
  stopifnot(is.data.table(id))

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- x <- ice_centered <- NULL # due to NSE notes in R CMD check

  # exact duplicate rows are dropped before aggregating
  unique(id)[
    ,
    .(
      pdp_centered = mean(ice_centered),
      pdp_centered_se = sd(ice_centered, na.rm = FALSE) / sqrt(.N)
    ),
    by = .(x)
  ]
}
|
| 148 | ||
| 149 |
# generate PDP data from stored ICE predictions |
|
| 150 |
generate_pdp_data <- function(predictions, x, logodds_to_prob = FALSE) {
  # stored ICE predictions -> ICE data; "y" is renamed to the column name that
  # calculate_pdp_data() aggregates
  ice_dt <- generate_ice_data(predictions = predictions, x = x, logodds_to_prob = logodds_to_prob)
  setnames(ice_dt, old = "y", new = "ice_centered")

  # aggregate, then rename the result columns to y / y_se
  pdp_dt <- calculate_pdp_data(ice_dt)
  setnames(
    pdp_dt,
    old = c("pdp_centered", "pdp_centered_se"),
    new = c("y", "y_se")
  )
  pdp_dt
}
|
| 158 | ||
| 159 | ||
| 160 |
# rough sculpture ---------- |
|
| 161 | ||
| 162 |
check_data <- function(dat) {
  # must be a data.frame without missing values
  checkmate::assert_data_frame(dat, any.missing = FALSE)

  # coerce to a plain data.frame (remove tbl_df etc)
  plain_dat <- as.data.frame(dat)
  plain_dat
}
|
| 166 | ||
| 167 |
check_upf <- function(upf, dat) {
  # the prediction function must take exactly one argument
  checkmate::assert_function(upf, nargs = 1)

  # run it once on `dat`, validate the result and hand the predictions back
  preds <- upf(dat)
  check_upf_output(dat = dat, output = preds)
  preds
}
|
| 173 | ||
| 174 |
check_upf_output <- function(dat, output) {
  # `output` has to be either a finite numeric vector or a factor, in both
  # cases without missing values and with one element per row of `dat`
  checkmate::assert(
    checkmate::check_numeric(
      output,
      finite = TRUE, any.missing = FALSE, len = nrow(dat)
    ),
    checkmate::check_factor(
      output,
      any.missing = FALSE, len = nrow(dat)
    )
  )
  invisible(NULL)
}
|
| 181 | ||
| 182 | ||
| 183 |
#' Create a rough model |
|
| 184 |
#' |
|
| 185 |
#' @param dat Data to create the rough model from. |
|
| 186 |
#' Must be a product marginal dataset (see `sample_marginals`) |
|
| 187 |
#' with covariates only (i.e. without response). |
|
| 188 |
#' @param model_predict_fun Function that returns predictions given a dataset. |
|
| 189 |
#' @param n_ice Number of ICE curves to generate. Defaults to 10. |
|
| 190 |
#' @param seed (`NULL`) or seed for exact reproducibility. |
|
| 191 |
#' @param verbose (`integer`) 0 for silent run, > 0 for messages. |
|
| 192 |
#' @param allow_par (`logical`) Allow parallel computation? Defaults to `FALSE`. |
|
| 193 |
#' @param model_predict_fun_export For parallel computation only. |
|
| 194 |
#' If there is a parallel backend registered (see `parallel_set()`), |
|
| 195 |
#' then use this to export variables used in `model_predict_fun` (like model). |
|
| 196 |
#' This is passed to `foreach::foreach(..., .export = model_predict_fun_export)`. |
|
| 197 |
#' @param data_as_marginals (`logical`) Use the provided data `dat` as already sampled dataset? |
|
| 198 |
#' Defaults to `FALSE`. |
|
| 199 |
#' |
|
| 200 |
#' @details For parallel computation, use [parallel_set()] and set `allow_par` to `TRUE`. |
|
| 201 |
#' Note that parallel computation may fail if the model is too big and there is not enough memory. |
|
| 202 |
#' |
|
| 203 |
#' @return Object of classes `rough` and `sculpture`. |
|
| 204 |
#' @export |
|
| 205 |
#' |
|
| 206 |
#' @examples |
|
| 207 |
#' df <- mtcars |
|
| 208 |
#' df$vs <- as.factor(df$vs) |
|
| 209 |
#' model <- rpart::rpart( |
|
| 210 |
#' hp ~ mpg + carb + vs, |
|
| 211 |
#' data = df, |
|
| 212 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 213 |
#' ) |
|
| 214 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 215 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 216 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 217 |
#' |
|
| 218 |
#' rs <- sculpt_rough( |
|
| 219 |
#' dat = pm, |
|
| 220 |
#' model_predict_fun = model_predict, |
|
| 221 |
#' n_ice = 10, |
|
| 222 |
#' seed = 1, |
|
| 223 |
#' verbose = 0 |
|
| 224 |
#' ) |
|
| 225 |
#' |
|
| 226 |
#' class(rs) |
|
| 227 |
#' head(predict(rs)) |
|
| 228 |
#' |
|
| 229 |
#' # lm model without interaction -> additive -> same predictions |
|
| 230 |
#' model <- lm(hp ~ mpg + carb + vs, data = df) |
|
| 231 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 232 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 233 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 234 |
#' |
|
| 235 |
#' rs <- sculpt_rough( |
|
| 236 |
#' dat = pm, |
|
| 237 |
#' model_predict_fun = model_predict, |
|
| 238 |
#' n_ice = 10, |
|
| 239 |
#' seed = 1, |
|
| 240 |
#' verbose = 0 |
|
| 241 |
#' ) |
|
| 242 |
#' |
|
| 243 |
#' class(rs) |
|
| 244 |
#' head(predict(rs)) |
|
| 245 |
#' head(predict(model, pm)) |
|
| 246 |
#' |
|
| 247 |
sculpt_rough <- function(dat, model_predict_fun, n_ice = 10,
                         seed = NULL, verbose = 0,
                         allow_par = FALSE,
                         model_predict_fun_export = NULL,
                         data_as_marginals = FALSE) {
  # Per covariate: build ICE curves, average them into a centered PDP and wrap
  # the PDP into a prediction function. The per-covariate results are returned
  # as a named list of class c("rough", "sculpture") with attributes
  # "offset", "var_imp", "cumul_R2" and "range".
  dat <- check_data(dat)
  predictions <- check_upf(model_predict_fun, dat)
  # at least one ICE curve is required
  checkmate::assert_integerish(n_ice, lower = 1, any.missing = FALSE, len = 1)
  checkmate::assert_integerish(verbose, lower = 0, any.missing = FALSE, len = 1)
  checkmate::assert_flag(allow_par)
  checkmate::assert_flag(data_as_marginals)

  covariates <- colnames(dat)

  `%operand%` <- define_foreach_operand(allow_par = allow_par)
  res <- foreach::foreach(col = covariates, .export = model_predict_fun_export) %operand% {
    # verbosity: every variable for < 10 covariates, else the 1st and every 10th
    matched <- match(col, covariates)
    if (verbose > 0) {
      if (length(covariates) < 10) {
        message(paste("Sculpting variable:", matched, "/", length(covariates)))
      } else {
        if ((matched == 1) || (matched %% 10 == 0)) {
          message(paste("Sculpting variable:", matched, "/", length(covariates)))
        }
      }
    }

    # values of the remaining covariates, one row per ICE curve
    if (length(covariates) > 1) {
      if (data_as_marginals) {
        # `dat` is used as an already sampled dataset: sample n_ice rows
        if (!is.null(seed)) {
          set.seed(seed)
        }
        dat_others <- dat[covariates[covariates != col]]
        # drop = FALSE: keep a data.frame when only one other covariate is left
        dat_subs <- dat_others[sample(nrow(dat_others), n_ice, replace = TRUE), , drop = FALSE]
      } else {
        # generate product marginals
        dat_subs <- sample_marginals(
          dat = dat[setdiff(covariates, col)],
          n = n_ice,
          seed = seed
        )
        stopifnot(nrow(dat_subs) == n_ice)
      }
    } else {
      # single covariate: there is nothing to combine with
      dat_subs <- NULL
    }

    # calculate ice
    ice <- calculate_ice_data(
      sub = dat_subs,
      predict_fun = model_predict_fun,
      x = dat[[col]],
      x_name = col,
      col_order = colnames(dat)
    )

    # calculate pdp
    pdp <- calculate_pdp_data(id = ice)

    # continuous flag
    is_continuous <- is.numeric(dat[[col]])

    # interpolation function - used for making predictions
    # The function gets its own environment (parent: globalenv()) that holds
    # only the PDP values it needs.
    af <- y <- x <- NULL # due to NSE notes in R CMD check
    if (is_continuous && nrow(pdp) > 1) {
      # continuous: interpolate between PDP points; rule = 2 is passed to
      # approxfun() for values outside the observed range
      e_predict_fun <- new.env(parent = globalenv())
      e_predict_fun$x <- pdp[["x"]]
      e_predict_fun$y <- pdp[["pdp_centered"]]
      e_predict_fun$af <- approxfun(x = e_predict_fun$x, y = e_predict_fun$y, rule = 2)
      predict_fun <- function(v) af(v)
      environment(predict_fun) <- e_predict_fun
    } else {
      # discrete (or a single PDP point): exact lookup, unseen values give 0
      e_predict_fun <- new.env(parent = globalenv())
      e_predict_fun$x <- pdp[["x"]]
      e_predict_fun$y <- pdp[["pdp_centered"]]
      predict_fun <- function(v) {
        ind <- match(v, x)
        ifelse(is.na(ind), 0, y[ind])
      }
      environment(predict_fun) <- e_predict_fun
    }

    return(list(
      subsets = dat_subs,
      predict = predict_fun,
      ice_centered = split(ice$ice_centered, ice$line_id),
      ice = split(ice$ice, ice$line_id),
      is_discrete = !is_continuous,
      x = dat[[col]],
      x_name = col
    ))
  }

  names(res) <- covariates
  attr(res, "offset") <- mean(predictions)
  class(res) <- c("rough", "sculpture", class(res))

  # evaluate the sculpture on the data it was built from
  es <- eval_sculpture(
    sculpture = res,
    data = as.data.frame(as.data.table(lapply(res, "[[", "x")))
  )

  # calculate variable importance
  dat_var <- calc_dir_var_imp_pdp(es$pdp)
  feat_order <- levels(dat_var$feature)

  # calculate cumulative R2
  dat_R2_cumul <- calc_cumul_R2_pdp(
    dt = es$pdp,
    feat_order = feat_order,
    model_predictions = es$prediction$pred,
    model_offset = es$offset
  )

  # calculate range
  dat_range <- calc_range_pdp(es$pdp)

  attr(res, "var_imp") <- dat_var
  attr(res, "cumul_R2") <- dat_R2_cumul
  attr(res, "range") <- dat_range

  return(res)
}
|
| 371 | ||
| 372 | ||
| 373 | ||
| 374 |
# detailed sculpture -------- |
|
| 375 | ||
| 376 |
#' Create a detailed model with user defined smoother |
|
| 377 |
#' |
|
| 378 |
#' @param rs Rough model, i.e. object of classes `rough` and `sculpture`. |
|
| 379 |
#' @param smoother_fit Smoother fitting function. |
|
| 380 |
#' @param smoother_predict Smoother prediction function. |
|
| 381 |
#' @param missings (`NULL`) or single value or a named vector. |
|
| 382 |
#' Specifies the value(-s) that stand for the missing values. |
|
| 383 |
#' If `NULL`, then no missing value handling is carried out. |
|
| 384 |
#' If single value, then it is assumed that this value is used for flagging missing values across |
|
| 385 |
#' all continuous variables. |
|
| 386 |
#' If named vector, then the names are used to refer to continuous variables and the values for |
|
| 387 |
#' flagging missing values in that variable. |
|
| 388 |
#' @param verbose (`integer`) 0 for silent run, > 0 for messages. |
|
| 389 |
#' @param allow_par (`logical`) Allow parallel computation? Defaults to `FALSE`. |
|
| 390 |
#' |
|
| 391 |
#' @details For parallel computation, use [parallel_set()] and set `allow_par` to `TRUE`. |
|
| 392 |
#' Note that parallel computation may fail if the model is too big and there is not enough memory. |
|
| 393 |
#' |
|
| 394 |
#' @section Custom smoothers: |
|
| 395 |
#' If none of the predefined smoothers ([sculpt_detailed_gam()], [sculpt_detailed_lm()]) |
|
| 396 |
#' suits your needs, you can define your own smoothers. |
|
| 397 |
#' You need to define 2 functions: `smoother_fit` and `smoother_predict`: |
|
| 398 |
#' |
|
| 399 |
#' `smoother_fit` takes 5 arguments ("x", "y", "is_discrete", "column_name", "na_ind") and
|
|
| 400 |
#' returns a model fit. "x" are the feature values, "y" are the PDP values, |
|
| 401 |
#' "is_discrete" flags a discrete feature, "column_name" holds the feature name, |
|
| 402 |
#' and "na_ind" passes the NA value from `missings` (or NULL by default). |
|
| 403 |
#' |
|
| 404 |
#' `smoother_predict` takes also 5 arguments ("smoother", "new_x", "is_discrete", "column_name",
|
|
| 405 |
#' "na_ind") and returns predictions as a vector. "smoother" is the model fit returned from |
|
| 406 |
#' `smoother_fit`, "new_x" are the feature values that we want to predict, "is_discrete", |
|
| 407 |
#' "column_name", and "na_ind" have the same purpose as in `smoother_fit`. |
|
| 408 |
#' See also Examples. |
|
| 409 |
#' |
|
| 410 |
#' @return Object of classes `detailed` and `sculpture`. |
|
| 411 |
#' @export |
|
| 412 |
#' |
|
| 413 |
#' @examples |
|
| 414 |
#' df <- mtcars |
|
| 415 |
#' df$vs <- as.factor(df$vs) |
|
| 416 |
#' model <- rpart::rpart( |
|
| 417 |
#' hp ~ mpg + carb + vs, |
|
| 418 |
#' data = df, |
|
| 419 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 420 |
#' ) |
|
| 421 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 422 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 423 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 424 |
#' |
|
| 425 |
#' rs <- sculpt_rough( |
|
| 426 |
#' dat = pm, |
|
| 427 |
#' model_predict_fun = model_predict, |
|
| 428 |
#' n_ice = 10, |
|
| 429 |
#' seed = 1, |
|
| 430 |
#' verbose = 0 |
|
| 431 |
#' ) |
|
| 432 |
#' |
|
| 433 |
#' # define custom smoother |
|
| 434 |
#' # - gam with 3 knots for variable "mpg" |
|
| 435 |
#' # - gam with 5 knots for variable "carb" |
|
| 436 |
#' # - lm for any discrete variable |
|
| 437 |
#' library(mgcv) |
|
| 438 |
#' my_smoother <- function(x, y, is_discrete, column_name, na_ind = NULL) {
|
|
| 439 |
#' if (column_name == "mpg") {
|
|
| 440 |
#' gam(y ~ s(x, k = 3)) |
|
| 441 |
#' } else if (column_name == "carb") {
|
|
| 442 |
#' gam(y ~ s(x, k = 5)) |
|
| 443 |
#' } else if (is_discrete) {
|
|
| 444 |
#' lm(y ~ x) |
|
| 445 |
#' } else {
|
|
| 446 |
#' stop("Undefined smoother")
|
|
| 447 |
#' } |
|
| 448 |
#' } |
|
| 449 |
#' |
|
| 450 |
#' # define appropriate predict function |
|
| 451 |
#' # - predict.gam returns an array, we need to convert it to vector |
|
| 452 |
#' # - if-else branch for illustration purposes |
|
| 453 |
#' my_smoother_predict <- function(smoother, new_x, is_discrete, column_name, na_ind = NULL) {
|
|
| 454 |
#' if (inherits(smoother, "gam")) {
|
|
| 455 |
#' # as.numeric: convert array to vector |
|
| 456 |
#' as.numeric(predict(smoother, newdata = data.frame(x = new_x))) |
|
| 457 |
#' } else {
|
|
| 458 |
#' predict(smoother, newdata = data.frame(x = new_x)) |
|
| 459 |
#' } |
|
| 460 |
#' } |
|
| 461 |
#' |
|
| 462 |
#' ds <- sculpt_detailed_generic( |
|
| 463 |
#' rs = rs, |
|
| 464 |
#' smoother_fit = my_smoother, |
|
| 465 |
#' smoother_predict = my_smoother_predict |
|
| 466 |
#' ) |
|
| 467 |
#' class(ds) |
|
| 468 |
#' \dontrun{
|
|
| 469 |
#' # see components |
|
| 470 |
#' g_component(ds)$continuous |
|
| 471 |
#' } |
|
| 472 |
#' |
|
| 473 |
#' |
|
| 474 |
#' # another example with constrained gam (cgam) package |
|
| 475 |
#' \dontrun{
|
|
| 476 |
#' library(cgam) |
|
| 477 |
#' |
|
| 478 |
#' cgam_smoother <- function(x, y, is_discrete, column_name, na_ind = NULL) {
|
|
| 479 |
#' if (column_name == "carb") {
|
|
| 480 |
#' cgam(y ~ s.incr(x, numknots = 3)) |
|
| 481 |
#' } else if (column_name == "mpg") {
|
|
| 482 |
#' cgam(y ~ s.decr(x, numknots = 3)) |
|
| 483 |
#' } else {
|
|
| 484 |
#' cgam(y ~ x) |
|
| 485 |
#' } |
|
| 486 |
#' } |
|
| 487 |
#' |
|
| 488 |
#' cgam_predict <- function(smoother, new_x, is_discrete, column_name, na_ind = NULL) {
|
|
| 489 |
#' predict(smoother, newData = data.frame(x = new_x))$fit |
|
| 490 |
#' } |
|
| 491 |
#' |
|
| 492 |
#' ds2 <- sculpt_detailed_generic( |
|
| 493 |
#' rs = rs, |
|
| 494 |
#' smoother_fit = cgam_smoother, |
|
| 495 |
#' smoother_predict = cgam_predict |
|
| 496 |
#' ) |
|
| 497 |
#' |
|
| 498 |
#' # see components |
|
| 499 |
#' g_component(ds2)$continuous |
|
| 500 |
#' } |
|
| 501 |
sculpt_detailed_generic <- function(rs, smoother_fit, smoother_predict,
                                    missings = NULL, verbose = 0, allow_par = FALSE) {
  # Per variable of the rough model `rs`: fit the smoother on the rough PDP
  # values and wrap it into a prediction function. The results are returned
  # as a named list of class c("detailed", "sculpture") with attributes
  # "offset", "var_imp", "cumul_R2" and "range".
  checkmate::assert_class(rs, "sculpture")
  checkmate::assert_class(rs, "rough")
  checkmate::assert_function(
    smoother_fit,
    args = c("x", "y", "is_discrete", "column_name", "na_ind")
  )
  checkmate::assert_function(
    smoother_predict,
    args = c("smoother", "new_x", "is_discrete", "column_name", "na_ind")
  )
  checkmate::assert(
    checkmate::check_null(missings),
    checkmate::check_atomic(missings, any.missing = FALSE, len = 1),
    checkmate::check_atomic(missings, any.missing = FALSE, max.len = length(rs), names = "named")
  )

  # names of the continuous variables (is_discrete is FALSE)
  check_continuous <- vapply(rs, "[[", logical(1), "is_discrete")
  check_continuous <- names(Filter(isFALSE, check_continuous))

  # normalise `missings` to a list named by variable
  # NOTE(review): a *named* vector of length 1 also takes the first branch,
  # i.e. its name is ignored and the value is applied to all continuous
  # variables - confirm that this is intended.
  if (length(missings) == 1) {
    missings <- rep(list(missings), length(check_continuous))
    names(missings) <- check_continuous
  } else if (length(missings) != 0) {
    missings <- as.list(missings)
    checkmate::assert_subset(names(missings), check_continuous, .var.name = "missings")
  }
  checkmate::assert_integerish(verbose, lower = 0, any.missing = FALSE, len = 1)
  checkmate::assert_flag(allow_par)

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  x <- NULL # due to NSE notes in R CMD check

  `%operand%` <- define_foreach_operand(allow_par = allow_par)
  res <- foreach::foreach(col = names(rs)) %operand% {
    # verbosity: every variable for < 10 variables, else the 1st and every 10th
    matched <- match(col, names(rs))
    if (verbose > 0) {
      if (length(rs) < 10) {
        message(paste("Sculpting variable:", matched, "/", length(rs)))
      } else {
        if ((matched == 1) || (matched %% 10 == 0)) {
          message(paste("Sculpting variable:", matched, "/", length(rs)))
        }
      }
    }

    # build the smoother from PDPs (based on original data, i.e. with duplicates)
    pdp_dupl <- data.table(x = rs[[col]]$x, pdp_centered = rs[[col]]$predict(rs[[col]]$x))
    pdp_dupl <- pdp_dupl[order(x)]

    # memory optimization: use a clean environment for the predict function.
    # The parent is set explicitly (as in sculpt_rough()); the default parent
    # of new.env() is the calling frame, which the returned function would
    # then keep alive together with everything stored in it.
    e_predict_fun <- new.env(parent = globalenv())

    # estimate smoothers
    e_predict_fun$smoother <- smoother_fit(
      x = pdp_dupl$x,
      y = pdp_dupl$pdp_centered,
      is_discrete = rs[[col]]$is_discrete,
      column_name = col,
      na_ind = `if`(!is.null(missings[[col]]), pdp_dupl$x == missings[[col]])
    )

    # add the rest of variables into the function environment
    e_predict_fun$smoother_predict <- smoother_predict
    e_predict_fun$is_discrete <- rs[[col]]$is_discrete
    e_predict_fun$col <- col
    e_predict_fun$missings_flag <- missings[[col]]

    # smoother prediction function; every name it uses besides its argument
    # is stored in `e_predict_fun` above
    smoother <- is_discrete <- missings_flag <- NULL # due to NSE notes in R CMD check
    predict_fun <- function(x) {
      smoother_predict(
        smoother = smoother,
        new_x = x,
        is_discrete = is_discrete,
        column_name = col,
        na_ind = `if`(!is.null(missings_flag), x == missings_flag)
      )
    }
    # use the defined environment as the environment of the function
    environment(predict_fun) <- e_predict_fun

    # check the output of the smoother prediction
    pf_check <- predict_fun(pdp_dupl$x)
    if (!is.vector(pf_check) || is.character(pf_check)) {
      stop("The output of the `smoother_predict` needs to be a numeric/factor vector.")
    }

    return(list(
      predict = predict_fun,
      is_discrete = rs[[col]]$is_discrete,
      x = rs[[col]]$x,
      x_name = rs[[col]]$x_name,
      missings_flag = missings[[col]]
    ))
  }

  names(res) <- names(rs)
  attr(res, "offset") <- attr(rs, "offset")
  class(res) <- c("detailed", "sculpture", class(res))

  # evaluate the sculpture on the data stored in it
  es <- eval_sculpture(
    sculpture = res,
    data = as.data.frame(as.data.table(lapply(res, "[[", "x")))
  )

  # calculate variable importance
  dat_var <- calc_dir_var_imp_pdp(es$pdp)
  feat_order <- levels(dat_var$feature)

  # calculate cumulative R2
  dat_R2_cumul <- calc_cumul_R2_pdp(
    dt = es$pdp,
    feat_order = feat_order,
    model_predictions = es$prediction$pred,
    model_offset = es$offset
  )

  # calculate range
  dat_range <- calc_range_pdp(es$pdp)

  attr(res, "var_imp") <- dat_var
  attr(res, "cumul_R2") <- dat_R2_cumul
  attr(res, "range") <- dat_range

  return(res)
}
|
| 629 | ||
| 630 | ||
| 631 |
smoother_gam <- function(x, y, is_discrete, column_name, na_ind = NULL) {
  # make `s` visible to the formulas built below
  s <- mgcv::s
  # optional additive term for the missing-value indicator
  na_term <- if (is.null(na_ind)) NULL else " + na_ind"

  # discrete feature: plain gam on x; if that fails and x is constant, fall
  # back to an empty lm, otherwise report the failure
  if (is_discrete) {
    return(tryCatch(
      mgcv::gam(y ~ x),
      error = function(e) {
        if (length(unique(x)) != 1) {
          stop(paste(
            "Cannot fit a smoother for", column_name,
            "The error message is:", e$message
          ))
        }
        lm(y ~ 0)
      }
    ))
  }

  # continuous feature, last resort: linear term only (lm for a single point)
  fit_linear <- function(e) {
    if (length(x) == 1) {
      lm(as.formula(paste0("y ~ x", na_term)))
    } else {
      mgcv::gam(as.formula(paste0("y ~ x", na_term)))
    }
  }

  # continuous feature, second attempt: smooth term with k = 3
  fit_small_basis <- function(e) {
    tryCatch(
      mgcv::gam(as.formula(paste0("y ~ s(x, k = 3)", na_term))),
      error = fit_linear
    )
  }

  # continuous feature, first attempt: smooth term with k = -1
  tryCatch(
    mgcv::gam(as.formula(paste0("y ~ s(x, k = -1)", na_term))),
    error = fit_small_basis
  )
}
|
| 665 | ||
| 666 |
smoother_gam_predict <- function(smoother, new_x, is_discrete, column_name, na_ind = NULL) {
  # prediction data: x plus the missing-value indicator (if one is given)
  newdata <- data.frame(x = new_x)
  newdata$na_ind <- na_ind

  # A warning about factor levels that were not in the original fit is
  # handled by predicting only the values present in the model data and
  # returning 0 for the others; any other warning is turned into an error.
  on_warning <- function(w) {
    unseen_levels <- grepl("^factor levels .* not in original fit$", w$message)
    if (!unseen_levels) {
      stop("Unknown value for prediction")
    }
    idx_known <- new_x %in% smoother$model$x
    y <- numeric(length(new_x))
    y[idx_known] <- as.numeric(predict(smoother, newdata = newdata[idx_known, , drop = FALSE]))
    y[!idx_known] <- 0
    y
  }

  # as.numeric: convert the prediction to a plain numeric vector
  tryCatch(
    as.numeric(predict(smoother, newdata = newdata)),
    warning = on_warning
  )
}
|
| 684 | ||
| 685 | ||
| 686 |
#' Create a detailed model with gam smoother |
|
| 687 |
#' |
|
| 688 |
#' @inheritParams sculpt_detailed_generic |
|
| 689 |
#' |
|
| 690 |
#' @details For parallel computation, use [parallel_set()] and set `allow_par` to `TRUE`. |
|
| 691 |
#' Note that parallel computation may fail if the model is too big and there is not enough memory. |
|
| 692 |
#' |
|
| 693 |
#' @return Object of classes `detailed` and `sculpture`. |
|
| 694 |
#' @export |
|
| 695 |
#' |
|
| 696 |
#' @examples |
|
| 697 |
#' df <- mtcars |
|
| 698 |
#' df$vs <- as.factor(df$vs) |
|
| 699 |
#' model <- rpart::rpart( |
|
| 700 |
#' hp ~ mpg + carb + vs, |
|
| 701 |
#' data = df, |
|
| 702 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 703 |
#' ) |
|
| 704 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 705 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 706 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 707 |
#' |
|
| 708 |
#' rs <- sculpt_rough( |
|
| 709 |
#' dat = pm, |
|
| 710 |
#' model_predict_fun = model_predict, |
|
| 711 |
#' n_ice = 10, |
|
| 712 |
#' seed = 1, |
|
| 713 |
#' verbose = 0 |
|
| 714 |
#' ) |
|
| 715 |
#' |
|
| 716 |
#' ds <- sculpt_detailed_gam(rs) |
|
| 717 |
#' class(ds) |
|
| 718 |
#' |
|
| 719 |
sculpt_detailed_gam <- function(rs, missings = NULL, verbose = 0, allow_par = FALSE) {
  # Fail early with a clear message when mgcv is not available:
  # requireNamespace() only returns FALSE in that case, it does not stop.
  if (!requireNamespace("mgcv", quietly = TRUE)) {
    stop("Package \"mgcv\" is required for `sculpt_detailed_gam()`.", call. = FALSE)
  }

  # delegate to the generic sculptor with the gam-based fit/predict pair
  sculpt_detailed_generic(
    rs = rs, verbose = verbose,
    allow_par = allow_par,
    smoother_fit = smoother_gam,
    smoother_predict = smoother_gam_predict,
    missings = missings
  )
}
|
| 729 | ||
| 730 | ||
| 731 | ||
| 732 | ||
| 733 |
smoother_lm <- function(x, y, is_discrete, column_name, na_ind = NULL) {
  # optional additive term for the missing-value indicator
  na_term <- if (is.null(na_ind)) NULL else " + na_ind"

  # if the fit fails: an empty model for a constant x, otherwise report the
  # failure together with the original error message
  on_error <- function(e) {
    if (length(unique(x)) != 1) {
      stop(paste(
        "Cannot fit a smoother for", column_name,
        "The error message is:", e$message
      ))
    }
    lm(y ~ 0)
  }

  tryCatch(
    lm(as.formula(paste0("y ~ x", na_term))),
    error = on_error
  )
}
|
| 748 | ||
| 749 |
smoother_lm_predict <- function(smoother, new_x, is_discrete, column_name, na_ind = NULL) {
  # prediction data: x plus the missing-value indicator (if one is given)
  newdata <- data.frame(x = new_x)
  newdata$na_ind <- na_ind

  # For a discrete feature with levels unknown to the fit, predict only the
  # values present in the model data and return 0 for the others; any other
  # error is reported as an unknown value.
  on_error <- function(e) {
    new_level <- is_discrete && grepl("factor x has new level", e$message)
    if (!new_level) {
      stop("Unknown value for prediction")
    }
    idx_known <- new_x %in% smoother$model$x
    y <- numeric(length(new_x))
    y[idx_known] <- as.numeric(predict(smoother, newdata = newdata[idx_known, , drop = FALSE]))
    y[!idx_known] <- 0
    y
  }

  # unname: return predictions without names
  tryCatch(
    unname(predict(smoother, newdata = newdata)),
    error = on_error
  )
}
|
| 767 | ||
| 768 | ||
| 769 |
#' Create a detailed model with lm smoother |
|
| 770 |
#' |
|
| 771 |
#' @inheritParams sculpt_detailed_generic |
|
| 772 |
#' |
|
| 773 |
#' @details For parallel computation, use [parallel_set()] and set `allow_par` to `TRUE`. |
|
| 774 |
#' Note that parallel computation may fail if the model is too big and there is not enough memory. |
|
| 775 |
#' |
|
| 776 |
#' @return Object of classes `detailed` and `sculpture`. |
|
| 777 |
#' @export |
|
| 778 |
#' |
|
| 779 |
#' @examples |
|
| 780 |
#' df <- mtcars |
|
| 781 |
#' df$vs <- as.factor(df$vs) |
|
| 782 |
#' model <- rpart::rpart( |
|
| 783 |
#' hp ~ mpg + carb + vs, |
|
| 784 |
#' data = df, |
|
| 785 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 786 |
#' ) |
|
| 787 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 788 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 789 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 790 |
#' |
|
| 791 |
#' rs <- sculpt_rough( |
|
| 792 |
#' dat = pm, |
|
| 793 |
#' model_predict_fun = model_predict, |
|
| 794 |
#' n_ice = 10, |
|
| 795 |
#' seed = 1, |
|
| 796 |
#' verbose = 0 |
|
| 797 |
#' ) |
|
| 798 |
#' |
|
| 799 |
#' ds <- sculpt_detailed_lm(rs) |
|
| 800 |
#' class(ds) |
|
| 801 |
#' |
|
| 802 |
sculpt_detailed_lm <- function(rs, missings = NULL, verbose = 0, allow_par = FALSE) {
  # delegate to the generic builder with the `smoother_lm` /
  # `smoother_lm_predict` pair as fit and predict functions
  sculpt_detailed_generic(
    rs = rs,
    missings = missings,
    verbose = verbose,
    allow_par = allow_par,
    smoother_fit = smoother_lm,
    smoother_predict = smoother_lm_predict
  )
}
|
| 811 | ||
| 812 | ||
| 813 |
# polished sculpture -------- |
|
| 814 | ||
| 815 | ||
| 816 |
#' Create a polished model |
|
| 817 |
#' |
|
| 818 |
#' @param object Object of class `sculpture`, either `rough` or `detailed`. |
|
| 819 |
#' @param k Number of most important variables to keep. |
|
| 820 |
#' @param vars Vector of variables to keep. |
|
| 821 |
#' |
|
| 822 |
#' @return Object of classes `rough` / `detailed` and `sculpture`. |
|
| 823 |
#' @export |
|
| 824 |
#' |
|
| 825 |
#' @examples |
|
| 826 |
#' df <- mtcars |
|
| 827 |
#' df$vs <- as.factor(df$vs) |
|
| 828 |
#' model <- rpart::rpart( |
|
| 829 |
#' hp ~ mpg + carb + vs, |
|
| 830 |
#' data = df, |
|
| 831 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 832 |
#' ) |
|
| 833 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 834 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 835 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 836 |
#' |
|
| 837 |
#' rs <- sculpt_rough( |
|
| 838 |
#' dat = pm, |
|
| 839 |
#' model_predict_fun = model_predict, |
|
| 840 |
#' n_ice = 10, |
|
| 841 |
#' seed = 1, |
|
| 842 |
#' verbose = 0 |
|
| 843 |
#' ) |
|
| 844 |
#' |
|
| 845 |
#' ds <- sculpt_detailed_gam(rs) |
|
| 846 |
#' |
|
| 847 |
#' # this keeps only "mpg" |
|
| 848 |
#' ps <- sculpt_polished(ds, k = 1) |
|
| 849 |
#' |
|
| 850 |
sculpt_polished <- function(object, k = NULL, vars = NULL) {
  checkmate::assert_class(object, "sculpture")
  # `k` and `vars` are alternatives: at least one of them must be `NULL`
  checkmate::assert(
    checkmate::check_null(k),
    checkmate::check_null(vars)
  )

  if (is.null(k)) {
    # explicit selection: every requested variable must exist in the sculpture
    checkmate::assert_character(vars, min.len = 1, any.missing = FALSE)
    checkmate::assert_subset(vars, names(object))
  } else if (is.null(vars)) {
    checkmate::assert_number(k, lower = 1)
    # features in the order stored in the "cumul_R2" attribute
    ranked <- levels(attr(object, "cumul_R2")$feature)
    # cap `k` at the number of available features; indexing with `1:k` would
    # otherwise produce `NA` names and `NULL` elements when `k` is too large
    vars <- ranked[seq_len(min(floor(k), length(ranked)))]
  }

  # plain list subsetting drops attributes, so restore offset and class
  res <- object[vars]
  attr(res, "offset") <- attr(object, "offset")
  class(res) <- class(object)

  # evaluate the sculpture on the `x` values stored in each kept element
  es <- eval_sculpture(
    sculpture = res,
    data = as.data.frame(as.data.table(lapply(res, "[[", "x")))
  )

  # calculate variable importance
  dat_var <- calc_dir_var_imp_pdp(es$pdp)
  feat_order <- levels(dat_var$feature)

  # calculate cumulative R2
  dat_R2_cumul <- calc_cumul_R2_pdp(
    dt = es$pdp,
    feat_order = feat_order,
    model_predictions = es$prediction$pred,
    model_offset = es$offset
  )

  # calculate range
  dat_range <- calc_range_pdp(es$pdp)

  # attach the recomputed metrics for the reduced set of variables
  attr(res, "var_imp") <- dat_var
  attr(res, "cumul_R2") <- dat_R2_cumul
  attr(res, "range") <- dat_range

  return(res)
}
|
| 896 | ||
| 897 | ||
| 898 | ||
| 899 |
# utils ----- |
|
| 900 | ||
| 901 | ||
| 902 |
# Evaluate a sculpture on `data`.
# Returns a list with:
#   - `pdp`: long data.table with columns `rn` (row number), `feature`, `pdp_c`
#   - `offset`: the "offset" attribute of the sculpture
#   - `prediction`: data.table with `rn` and `pred` (sum of `pdp_c` per row
#     plus the offset), ordered by `rn`
eval_sculpture <- function(sculpture, data) {
  stopifnot(
    inherits(sculpture, "sculpture"),
    ncol(data) >= length(sculpture),
    all(names(sculpture) %in% colnames(data))
  )

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- rn <- pdp_c <- NULL # due to NSE notes in R CMD check

  # get offset for predictions
  offset <- attr(sculpture, "offset")

  # get predict functions
  interp_funs <- lapply(sculpture, "[[", "predict")

  # PDPs at data
  pdp <- lapply(names(sculpture), function(col) interp_funs[[col]](data[[col]]))
  names(pdp) <- names(sculpture)

  # reshape PDPs to create predictions
  # `seq_len()` instead of `1:nrow(data)`: the latter yields c(1, 0) for
  # zero-row data and no longer matches the number of rows
  pdp <- cbind(data.table(rn = seq_len(nrow(data))), do.call("cbind", pdp))
  pdp <- melt(pdp, id.vars = "rn", variable.name = "feature", value.name = "pdp_c")
  pred <- pdp[, .(pred = sum(pdp_c) + offset), .(rn)][order(rn)]

  return(list(
    pdp = pdp,
    offset = offset,
    prediction = pred
  ))
}
|
| 933 | ||
| 934 |
#' @export |
|
| 935 |
predict.sculpture <- function(object, newdata = NULL, ...) {
  if (!is.null(newdata)) {
    # all sculpted variables must be available in the supplied data
    checkmate::assert_subset(names(object), colnames(newdata))
    newdata <- as.data.frame(newdata)
  } else {
    # no data supplied: use the `x` values stored in the sculpture elements
    newdata <- as.data.frame(as.data.table(lapply(object, "[[", "x")))
  }

  evaluated <- eval_sculpture(
    sculpture = object,
    data = newdata[names(object)]
  )
  preds <- evaluated$prediction

  # predictions named by row number
  return(structure(preds$pred, names = preds$rn))
}
|
| 948 | ||
| 949 |
#' @export |
|
| 950 |
print.sculpture <- function(x, ...) {
  n_vars <- length(x)
  plural <- if (n_vars > 1) "s" else ""
  # one-line summary, e.g. "<First class> sculpture with 3 variables";
  # terminated by a newline so the console prompt starts on a fresh line
  cat(
    paste0(
      stringr::str_to_sentence(class(x)[1]), " sculpture with ",
      n_vars, " variable", plural, "\n"
    )
  )
  # print methods return their argument invisibly
  invisible(x)
}
|
| 959 | ||
| 960 |
# inverse logit: maps log-odds to probabilities
inv.logit <- function(x) {
  neg_exp <- exp(-x)
  1 / (1 + neg_exp)
}
| 1 |
# sculpture metrics -------- |
|
| 2 | ||
| 3 | ||
| 4 |
#' Various metrics related to model sculpting |
|
| 5 |
#' |
|
| 6 |
#' @name var_imp |
|
| 7 |
#' |
|
| 8 |
#' @param object `sculpture` |
|
| 9 |
#' @param newdata (Optional) Data to calculate the importance from. |
|
| 10 |
#' If omitted, the data that were provided to build the sculpture are used. |
|
| 11 |
#' |
|
| 12 |
#' @return `data.table` with the requested metrics. |
|
| 13 |
#' |
|
| 14 |
#' @examples |
|
| 15 |
#' df <- mtcars |
|
| 16 |
#' df$vs <- as.factor(df$vs) |
|
| 17 |
#' model <- rpart::rpart( |
|
| 18 |
#' hp ~ mpg + carb + vs, |
|
| 19 |
#' data = df, |
|
| 20 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 21 |
#' ) |
|
| 22 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 23 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 24 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 25 |
#' |
|
| 26 |
#' rs <- sculpt_rough( |
|
| 27 |
#' dat = pm, |
|
| 28 |
#' model_predict_fun = model_predict, |
|
| 29 |
#' n_ice = 10, |
|
| 30 |
#' seed = 1, |
|
| 31 |
#' verbose = 0 |
|
| 32 |
#' ) |
|
| 33 |
#' |
|
| 34 |
#' # show direct variable importance |
|
| 35 |
#' calc_dir_var_imp(rs) |
|
| 36 |
#' |
|
| 37 |
#' # show cumulative approximation R^2 |
|
| 38 |
#' calc_cumul_R2(rs) |
|
| 39 |
NULL |
|
| 40 | ||
| 41 | ||
| 42 |
# Direct variable importance from a long PDP table (`rn`, `feature`, `pdp_c`).
# Per feature: variance of `pdp_c`, the variance of the per-row PDP sums, and
# their ratio. Rows are sorted by decreasing ratio and `feature` is turned
# into a factor whose levels follow that order.
calc_dir_var_imp_pdp <- function(dt) {
  stopifnot(
    all(c("rn", "feature", "pdp_c") %in% colnames(dt)),
    !"total" %in% tolower(unique(dt$feature)),
    nrow(dt) == nrow(unique(dt[, .(rn, feature)]))
  )

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- rn <- feature <- pdp_c <- ratio <- variance <- variance_total <-
    NULL # due to NSE notes in R CMD check

  # variance of the row-wise sums of PDP contributions
  row_sums <- dt[, .(pdp_c = sum(pdp_c)), .(rn)]
  var_total <- row_sums[, var(pdp_c)]

  # variance per feature, relative to the total
  dat_var <- dt[, .(variance = var(pdp_c), variance_total = var_total), .(feature)]
  dat_var[, ratio := variance / variance_total]
  dat_var <- dat_var[order(ratio, decreasing = TRUE)]

  # define as factor to keep the order
  dat_var[, feature := factor(feature, levels = feature)]
  return(dat_var)
}
|
| 72 | ||
| 73 |
#' @describeIn var_imp Direct variable importance |
|
| 74 |
#' @export |
|
| 75 |
calc_dir_var_imp <- function(object, newdata = NULL) {
  checkmate::assert_class(object, "sculpture")

  if (!is.null(newdata)) {
    # recompute the importance on the supplied data
    checkmate::assert_data_frame(newdata, any.missing = FALSE)
    evaluated <- eval_sculpture(sculpture = object, data = newdata)
    return(calc_dir_var_imp_pdp(evaluated$pdp))
  }

  # no data supplied: return the importance stored on the sculpture
  attr(object, "var_imp")
}
|
| 88 | ||
| 89 |
# Cumulative R2 from a long PDP table: for each prefix of `feat_order`, the
# prediction built from only those features (plus `model_offset`) is scored
# against `model_predictions` with `metrics_R2` and the quadratic score.
# Returns a data.table with `feature` (factor ordered as `feat_order`) and `R2`.
calc_cumul_R2_pdp <- function(dt, feat_order, model_predictions, model_offset) {
  stopifnot(
    is.data.table(dt),
    all(c("rn", "feature", "pdp_c") %in% colnames(dt)),
    !"total" %in% tolower(unique(dt$feature)),
    nrow(dt) == nrow(unique(dt[, .(rn, feature)])),
    length(model_predictions) == length(unique(dt$rn)),
    is.character(feat_order),
    is.numeric(model_offset)
  )

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- rn <- feature <- pdp_c <- preds <- NULL # due to NSE notes in R CMD check

  # R2 of the partial model using the first `n_feat` features
  r2_for_prefix <- function(n_feat) {
    kept <- feat_order[seq_len(n_feat)]
    partial <- dt[
      feature %in% kept,
      .(preds = sum(pdp_c) + model_offset),
      .(rn)
    ]
    metrics_R2(
      score_fun = "score_quadratic",
      y = model_predictions,
      y_hat = partial[order(rn), preds]
    )
  }

  R2_cumul <- vapply(seq_along(feat_order), r2_for_prefix, numeric(1))

  return(
    data.table(feature = factor(feat_order, levels = feat_order), R2 = R2_cumul)
  )
}
|
| 125 | ||
| 126 | ||
| 127 |
#' @describeIn var_imp Calculate cumulative approximation of R^2 |
|
| 128 |
#' @export |
|
| 129 |
calc_cumul_R2 <- function(object, newdata = NULL) {
  checkmate::assert_class(object, "sculpture")

  # no data supplied: return the values stored on the sculpture
  if (is.null(newdata)) {
    return(attr(object, "cumul_R2"))
  }
  checkmate::assert_data_frame(newdata, any.missing = FALSE)

  evaluated <- eval_sculpture(sculpture = object, data = newdata)

  # feature order is taken from the direct variable importance on `newdata`
  importance <- calc_dir_var_imp_pdp(dt = evaluated$pdp)

  calc_cumul_R2_pdp(
    dt = evaluated$pdp,
    feat_order = levels(importance$feature),
    model_predictions = evaluated$prediction$pred,
    model_offset = evaluated$offset
  )
}
|
| 150 | ||
| 151 | ||
| 152 |
# calculate range - for plots (facet sorting) |
|
| 153 |
calc_range_pdp <- function(dt) {
  stopifnot(
    all(c("rn", "feature", "pdp_c") %in% colnames(dt)),
    nrow(dt) == nrow(unique(dt[, .(rn, feature)]))
  )

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  . <- rn <- feature <- pdp_c <- NULL # due to NSE notes in R CMD check

  # spread (max - min) of the PDP values per feature, widest first
  spans <- dt[, .(range = max(pdp_c) - min(pdp_c)), .(feature)]
  spans <- spans[order(-range)]

  # factor levels follow the sorted order
  spans[, feature := factor(feature, levels = feature)][]

  return(spans)
}
|
| 172 | ||
| 173 | ||
| 174 | ||
| 175 |
# generic metrics -------- |
|
| 176 | ||
| 177 | ||
| 178 |
#' Various metrics for measuring model performance. |
|
| 179 |
#' |
|
| 180 |
#' @name metrics |
|
| 181 |
#' @param score_fun A scoring function: `score_quadratic`, `score_log_loss`, |
|
| 182 |
#' or a user-defined scoring rule. See below for more details. |
|
| 183 |
#' @param y Vector of observations. |
|
| 184 |
#' @param y_hat Vector of predictions. |
|
| 185 |
#' @param y_hat_calib Vector of calibrated predictions. See below for more details. |
|
| 186 |
#' @param na_rm Logical, defaults to `FALSE`. Should NAs be removed? |
|
| 187 |
#' @param rev_fct Logical, defaults to `FALSE`. Switch the factor level of |
|
| 188 |
#' the data before performing calibration. Only relevant for binary response. |
|
| 189 |
#' |
|
| 190 |
#' @section Scoring function: |
|
| 191 |
#' One can use predefined scores like `score_quadratic` or `score_log_loss`. |
|
| 192 |
#' If those do not fit the needs, a user-defined scoring function can also be used. |
|
| 193 |
#' This function needs to take exactly 3 arguments: `y` (truth values), |
|
| 194 |
#' `y_hat` (estimated values), and `na_rm` (should NAs be removed?): |
|
| 195 |
#' - both `y` and `y_hat` are numeric (not factors!) |
|
| 196 |
#' - `na_rm` is a scalar logical |
|
| 197 |
#' |
|
| 198 |
#' It needs to return a number. |
|
| 199 |
#' There is a utility function `check_score_fun` to check if the user-defined function is |
|
| 200 |
#' programmed correctly. |
|
| 201 |
#' It checks the input and the output, but not if the actual returned value makes sense. |
|
| 202 |
#' |
|
| 203 |
#' |
|
| 204 |
#' @section Calibration: |
|
| 205 |
#' To obtain calibrated predictions, |
|
| 206 |
#' fit a calibration model and predict based on that model. |
|
| 207 |
#' Users can use their own calibration model or make use of `metrics_fit_calib`, |
|
| 208 |
#' which fits an `mgcv::gam()` model with smoother `mgcv::s(., k = -1)` (automatic knot selection). |
|
| 209 |
#' If the input `y` is a factor, then a binomial family is used, otherwise a gaussian. |
|
| 210 |
#' NAs are always dropped. |
|
| 211 |
#' |
|
| 212 |
#' Continuous response example: |
|
| 213 |
#' ``` |
|
| 214 |
#' calibration_model <- metrics_fit_calib( |
|
| 215 |
#' y = truth, |
|
| 216 |
#' y_hat = prediction |
|
| 217 |
#' ) |
|
| 218 |
#' calib_pred <- predict(calibration_model) |
|
| 219 |
#' ``` |
|
| 220 |
#' |
|
| 221 |
#' Binary response example: |
|
| 222 |
#' ``` |
|
| 223 |
#' calibration_model <- metrics_fit_calib( |
|
| 224 |
#' y = factor(truth, levels = c("0", "1")),
|
|
| 225 |
#' y_hat = prediction |
|
| 226 |
#' ) |
|
| 227 |
#' calib_pred <- predict(calibration_model, type = "response") |
|
| 228 |
#' ``` |
|
| 229 |
#' In the binary case, make sure that: |
|
| 230 |
#' - `y` is a factor with correct level setting. |
|
| 231 |
#' Usually "0" is the reference (first) level and "1" is the event (second level). |
|
| 232 |
#' This may clash with `yardstick` setting where |
|
| 233 |
#' the first level is by default the "event" level. |
|
| 234 |
#' - `y_hat` are probabilities (not a log of odds). |
|
| 235 |
#' - returned calibrated predictions `calib_pred` are also probabilities by setting |
|
| 236 |
#' `type = "response"`. |
|
| 237 |
#' |
|
| 238 |
#' |
|
| 239 |
#' @return `metrics_fit_calib` returns an [mgcv::gam()] model fit, otherwise a number. |
|
| 240 |
#' |
|
| 241 |
#' @examples |
|
| 242 |
#' # Scores |
|
| 243 |
#' score_quadratic(y = c(1.34, 2.8), y_hat = c(1.34, 2.8)) # must be 0 |
|
| 244 |
#' score_quadratic(y = 0.5, 0) # must be 0.5**2 = 0.25 |
|
| 245 |
#' |
|
| 246 |
#' score_log_loss(y = c(0, 1), y_hat = c(0.01, 0.9)) # must be close to 0 |
|
| 247 |
#' score_log_loss(y = 0, y_hat = 0) # undefined |
|
| 248 |
#' |
|
| 249 |
#' check_score_fun(score_quadratic) # passes without errors |
|
| 250 |
#' |
|
| 251 |
#' # Metrics based on `lm` model |
|
| 252 |
#' mod <- lm(hp ~ ., data = mtcars) |
|
| 253 |
#' truth <- mtcars$hp |
|
| 254 |
#' pred <- predict(mod) |
|
| 255 |
#' |
|
| 256 |
#' # calibration fit and calibrated predictions |
|
| 257 |
#' calib_mod <- metrics_fit_calib(y = truth, y_hat = pred) |
|
| 258 |
#' calib_pred <- predict(calib_mod) |
|
| 259 |
#' |
|
| 260 |
#' metrics_unc(score_fun = "score_quadratic", y = truth) |
|
| 261 |
#' metrics_R2(score_fun = "score_quadratic", y = truth, y_hat = pred) |
|
| 262 |
#' metrics_DI(score_fun = "score_quadratic", y = truth, y_hat_calib = calib_pred) |
|
| 263 |
#' metrics_MI(score_fun = "score_quadratic", y = truth, y_hat = pred, y_hat_calib = calib_pred) |
|
| 264 |
#' # Note that R^2 = DI - MI |
|
| 265 |
#' metrics_r2(y = truth, y_hat = pred, y_hat_calib = calib_pred) |
|
| 266 |
#' |
|
| 267 |
#' # Metrics based on `glm` model (logistic regression) |
|
| 268 |
#' # Note the correct setting of levels |
|
| 269 |
#' mod <- glm(factor(vs, levels = c("0", "1")) ~ hp + mpg, data = mtcars, family = "binomial")
|
|
| 270 |
#' truth_fct <- factor(mtcars$vs, levels = c("0", "1"))
|
|
| 271 |
#' truth_num <- mtcars$vs |
|
| 272 |
#' pred <- predict(mod, type = "response") # type = "response" returns probabilities |
|
| 273 |
#' |
|
| 274 |
#' # calibration fit and calibrated predictions |
|
| 275 |
#' calib_mod <- metrics_fit_calib(y = truth_fct, y_hat = pred) |
|
| 276 |
#' calib_pred <- predict(calib_mod, type = "response") # type = "response" returns probabilities |
|
| 277 |
#' |
|
| 278 |
#' metrics_unc(score_fun = "score_quadratic", y = truth_num) |
|
| 279 |
#' metrics_R2(score_fun = "score_quadratic", y = truth_num, y_hat = pred) |
|
| 280 |
#' metrics_DI(score_fun = "score_quadratic", y = truth_num, y_hat_calib = calib_pred) |
|
| 281 |
#' metrics_MI(score_fun = "score_quadratic", y = truth_num, y_hat = pred, y_hat_calib = calib_pred) |
|
| 282 |
#' # Note that R^2 = DI - MI |
|
| 283 |
#' metrics_r2(y = truth_num, y_hat = pred, y_hat_calib = calib_pred) |
|
| 284 |
#' |
|
| 285 |
NULL |
|
| 286 | ||
| 287 |
# Keep only the positions that are complete (non-missing) across all inputs.
# Returns a list with one filtered element per argument, names preserved.
remove_missing <- function(...) {
  inputs <- list(...)
  keep <- complete.cases(...)
  lapply(inputs, function(vec) vec[keep])
}
|
| 291 | ||
| 292 |
#' @describeIn metrics Binary log loss score |
|
| 293 |
#' @export |
|
| 294 |
score_log_loss <- function(y, y_hat, na_rm = FALSE) {
  checkmate::assert_numeric(y)
  # `y_hat` is either as long as `y` or a single value
  checkmate::assert(
    checkmate::check_numeric(y_hat, len = length(y)),
    checkmate::check_numeric(y_hat, len = 1)
  )
  if (na_rm) {
    # a scalar `y_hat` passes validation, but `complete.cases()` requires
    # inputs of equal length, so expand it before dropping incomplete pairs
    if (length(y_hat) == 1L && length(y) > 1L) {
      y_hat <- rep_len(y_hat, length(y))
    }
    cleaned <- remove_missing(y = y, y_hat = y_hat)
    y <- cleaned[["y"]]
    y_hat <- cleaned[["y_hat"]]
  }
  # negative mean log-likelihood of `y` under predicted probabilities `y_hat`
  -mean(y * log(y_hat) + (1 - y) * log(1 - y_hat))
}
|
| 307 | ||
| 308 |
#' @describeIn metrics Quadratic score |
|
| 309 |
#' @export |
|
| 310 |
score_quadratic <- function(y, y_hat, na_rm = FALSE) {
  checkmate::assert_numeric(y)
  # `y_hat` is either as long as `y` or a single value
  checkmate::assert(
    checkmate::check_numeric(y_hat, len = length(y)),
    checkmate::check_numeric(y_hat, len = 1)
  )
  if (na_rm) {
    # a scalar `y_hat` passes validation, but `complete.cases()` requires
    # inputs of equal length, so expand it before dropping incomplete pairs
    if (length(y_hat) == 1L && length(y) > 1L) {
      y_hat <- rep_len(y_hat, length(y))
    }
    cleaned <- remove_missing(y = y, y_hat = y_hat)
    y <- cleaned[["y"]]
    y_hat <- cleaned[["y_hat"]]
  }
  # mean squared difference
  mean((y - y_hat)^2)
}
|
| 323 | ||
| 324 |
#' @describeIn metrics Utility function for checking the properties of a user-defined `score_fun`. |
|
| 325 |
#' @export |
|
| 326 |
check_score_fun <- function(score_fun) {
  if (!is.character(score_fun) && !is.function(score_fun)) {
    stop("`score_fun` must be a function.")
  }

  # the scoring function must accept `y`, `y_hat` and `na_rm`
  if (is.character(score_fun)) {
    # a function given by name is resolved first
    checkmate::assert_function(eval(str2lang(score_fun)), args = c("y", "y_hat", "na_rm"))
  } else {
    checkmate::assert_function(score_fun, args = c("y", "y_hat", "na_rm"))
  }

  # trial call on a tiny input: the result has to be a single number (NA allowed)
  probe <- do.call(score_fun, list(y = c(0.5, 0.6), y_hat = c(0.5, 0.55)))
  if (checkmate::test_number(probe, na.ok = TRUE)) {
    return(invisible(NULL))
  }
  stop("The return value of `score_fun` must be a number")
}
|
| 339 | ||
| 340 | ||
| 341 |
#' @describeIn metrics Uncertainty |
|
| 342 |
#' @export |
|
| 343 |
metrics_unc <- function(score_fun, y, na_rm = FALSE) {
  check_score_fun(score_fun)
  if (na_rm) {
    y <- remove_missing(y = y)[["y"]]
  }
  # score of the constant prediction `mean(y)`
  baseline <- rep_len(mean(y), length(y))
  do.call(score_fun, list(y = y, y_hat = baseline))
}
|
| 351 | ||
| 352 |
#' @describeIn metrics R^2 metric |
|
| 353 |
#' @export |
|
| 354 |
metrics_R2 <- function(score_fun, y, y_hat, na_rm = FALSE) {
  check_score_fun(score_fun)
  if (na_rm) {
    cleaned <- remove_missing(y = y, y_hat = y_hat)
    y <- cleaned[["y"]]
    y_hat <- cleaned[["y_hat"]]
  }
  # score of the predictions relative to the score of the constant `mean(y)`
  model_score <- do.call(score_fun, list(y = y, y_hat = y_hat))
  baseline_score <- do.call(
    score_fun,
    list(y = y, y_hat = rep_len(mean(y), length(y)))
  )
  1 - model_score / baseline_score
}
|
| 365 | ||
| 366 |
#' @describeIn metrics Fit calibration curve using [mgcv::gam()]. |
|
| 367 |
#' Note that NAs are always dropped. |
|
| 368 |
#' @export |
|
| 369 |
metrics_fit_calib <- function(y, y_hat, rev_fct = FALSE) {
  requireNamespace("mgcv")
  # make `s()` visible to the formulas below
  s <- mgcv::s

  # factor response -> binomial family, anything else -> gaussian
  fam <- if (is.factor(y)) binomial() else gaussian()
  if (is.factor(y) && rev_fct) {
    y <- factor(y, levels = rev(levels(y)))
  }

  # last resort: plain linear term
  fit_linear <- function(e) {
    mgcv::gam(y ~ y_hat, family = fam, na.action = "na.omit")
  }
  # second attempt: smoother with `k = 3`
  fit_small_basis <- function(e) {
    tryCatch(
      mgcv::gam(y ~ s(y_hat, k = 3), family = fam, na.action = "na.omit"),
      error = fit_linear
    )
  }

  # first attempt: smoother with `k = -1`
  tryCatch(
    mgcv::gam(y ~ s(y_hat, k = -1), family = fam, na.action = "na.omit"),
    error = fit_small_basis
  )
}
|
| 386 | ||
| 387 |
#' @describeIn metrics Discrimination index |
|
| 388 |
#' @export |
|
| 389 |
metrics_DI <- function(score_fun, y, y_hat_calib, na_rm = FALSE) {
  check_score_fun(score_fun)
  if (na_rm) {
    cleaned <- remove_missing(y = y, y_hat_calib = y_hat_calib)
    y <- cleaned[["y"]]
    y_hat_calib <- cleaned[["y_hat_calib"]]
  }
  # score of the constant prediction `mean(y)`; it appears in both the
  # numerator and the denominator, so compute it only once
  unc <- do.call(score_fun, list(y = y, y_hat = rep_len(mean(y), length(y))))
  calib_score <- do.call(score_fun, list(y = y, y_hat = y_hat_calib))
  (unc - calib_score) / unc
}
|
| 402 | ||
| 403 |
#' @describeIn metrics Miscalibration index |
|
| 404 |
#' @export |
|
| 405 |
metrics_MI <- function(score_fun, y, y_hat, y_hat_calib, na_rm = FALSE) {
  check_score_fun(score_fun)
  if (na_rm) {
    cleaned <- remove_missing(y = y, y_hat = y_hat, y_hat_calib = y_hat_calib)
    y <- cleaned[["y"]]
    y_hat <- cleaned[["y_hat"]]
    y_hat_calib <- cleaned[["y_hat_calib"]]
  }
  # difference between raw and calibrated score, relative to the score of
  # the constant prediction `mean(y)`
  raw_score <- do.call(score_fun, list(y = y, y_hat = y_hat))
  calib_score <- do.call(score_fun, list(y = y, y_hat = y_hat_calib))
  unc <- do.call(score_fun, list(y = y, y_hat = rep_len(mean(y), length(y))))
  (raw_score - calib_score) / unc
}
|
| 419 | ||
| 420 | ||
| 421 |
#' @describeIn metrics r^2 metric based on slope of `lm` |
|
| 422 |
#' @export |
|
| 423 |
metrics_r2 <- function(y, y_hat, y_hat_calib, na_rm = FALSE) {
  if (na_rm) {
    cleaned <- remove_missing(y = y, y_hat = y_hat, y_hat_calib = y_hat_calib)
    y <- cleaned[["y"]]
    y_hat <- cleaned[["y_hat"]]
    y_hat_calib <- cleaned[["y_hat_calib"]]
  } else if (anyNA(y) || anyNA(y_hat) || anyNA(y_hat_calib)) {
    # missing values without `na_rm`: the result is undefined
    return(NA)
  }

  # slope of the calibrated predictions regressed on the raw predictions
  slope <- coef(lm(y_hat_calib ~ y_hat))[2]
  r2 <- (slope * sd(y_hat) / sd(y))^2
  # an undefined value is reported as 0
  if (is.na(r2)) {
    r2 <- 0
  }
  unname(r2)
}
| 1 |
# colors ------- |
|
| 2 | ||
| 3 |
# function for generating colours |
|
| 4 |
# Returns `n` colours as hex strings: the first `n` entries of a fixed
# six-colour palette when `n < 7` and `hue_coloring` is `FALSE`, otherwise
# `n` colours produced by `hcl()` with hues spread between 15 and 375.
ms_color <- function(n, hue_coloring = FALSE) {
  # `seq_len(n)` instead of `1:n`, so that `n = 0` yields no colours
  # (`1:0` would select the first element)
  if (n < 7 && !hue_coloring) {
    c("#0a0a0a", "#14a3a8", "#e3211d", "#b15829", "#6a3d9a", "#34a02b")[seq_len(n)]
  } else {
    # n + 1 hues are generated and the last one is dropped
    hcl(h = seq(15, 375, length.out = n + 1), l = 35, c = 85)[seq_len(n)]
  }
}
|
| 11 | ||
| 12 | ||
| 13 |
# facets specification ------- |
|
| 14 | ||
| 15 |
#' Instructions for facet visualizations |
|
| 16 |
#' |
|
| 17 |
#' @param labels (`NULL`) or named character vector with variable labels. |
|
| 18 |
#' @param ncol (`NULL`) or number of columns in the facet. |
|
| 19 |
#' @param sort One of "alphabetical", "importance", or "range" - sorting of the facets. |
|
| 20 |
#' @param top_k (`NULL`) or number of most important features to show. |
|
| 21 |
#' @param subset (`NULL`) or a vector of variables to show. |
|
| 22 |
#' @param scales One of "free", "free_x", or "free_y" - axis scales of the graphs. |
|
| 23 |
#' |
|
| 24 |
#' @return List of class `facet_specification`. |
|
| 25 |
#' @export |
|
| 26 |
#' |
|
| 27 |
#' @examples |
|
| 28 |
#' \dontrun{
|
|
| 29 |
#' g_ice( |
|
| 30 |
#' sculpture, |
|
| 31 |
#' facet_spec = facet_specification( |
|
| 32 |
#' ncol = 3, # display 3 columns |
|
| 33 |
#' sort = "importance" # sort by importance |
|
| 34 |
#' ) |
|
| 35 |
#' ) |
|
| 36 |
#' } |
|
| 37 |
#' |
|
| 38 |
facet_specification <- function(labels = NULL,
                                ncol = NULL,
                                sort = "alphabetical",
                                top_k = NULL,
                                subset = NULL,
                                scales = "free_x") {
  # `labels`: `NULL` or a named character vector without missing values
  checkmate::assert(
    checkmate::check_character(labels, any.missing = FALSE, names = "named"),
    checkmate::check_null(labels)
  )

  # `ncol`: `NULL` or a single whole number >= 1
  checkmate::assert(
    checkmate::check_integerish(ncol, any.missing = FALSE, len = 1, lower = 1),
    checkmate::check_null(ncol)
  )

  checkmate::assert_character(sort, any.missing = FALSE, len = 1)
  checkmate::assert_subset(sort, c("alphabetical", "importance", "range"))

  # `top_k`: `NULL` or a single number >= 1
  checkmate::assert(
    checkmate::check_number(top_k, lower = 1),
    checkmate::check_null(top_k)
  )

  checkmate::assert(
    checkmate::check_character(subset),
    checkmate::check_null(subset)
  )

  checkmate::assert_character(scales, len = 1, any.missing = FALSE)
  checkmate::assert_subset(scales, c("free", "free_y", "free_x"))

  # scalar condition: use the short-circuit `&&` rather than elementwise `&`
  if (!is.null(top_k) && !is.null(subset)) {
    stop("Please use either `top_k` or `subset`, but not both together.")
  }

  out <- list(
    labels = labels, ncol = ncol, sort = sort, top_k = top_k,
    subset = subset, scales = scales
  )
  class(out) <- "facet_specification"

  return(out)
}
|
| 82 | ||
| 83 | ||
| 84 |
# Resolve a `facet_specification` against a sculpture.
# Returns a list with the (possibly subsetted and reordered) sculpture, the
# labels, the number of facet columns for continuous (`ncol_c`) and discrete
# (`ncol_d`) features, and the requested scales.
resolve_facet_specification <- function(obj, fs) {
  checkmate::assert_class(obj, "sculpture")
  checkmate::assert_class(fs, "facet_specification")

  # labels default to the variable names themselves
  labels <- fs$labels
  if (is.null(labels)) {
    labels <- structure(names(obj), names = names(obj))
  }
  checkmate::assert_character(
    labels,
    names = "named", len = length(obj), any.missing = FALSE,
    .var.name = "facet_specification$labels"
  )

  # facet order, then subset / top_k applied in that order
  ordered_feats <- resolve_facet_sort(obj = obj, facet_sort = fs$sort)
  obj <- resolve_facet_subset_topk(
    obj = obj, facet_subset = fs$subset, facet_top_k = fs$top_k,
    feat_ordered = ordered_feats
  )

  # number of facet columns, separately for continuous and discrete features
  is_continuous <- !vapply(obj, "[[", logical(1), "is_discrete")
  ncols <- resolve_facet_ncol(idx_c = is_continuous, facet_ncol = fs$ncol)

  list(
    object = obj,
    labels = labels,
    ncol_c = ncols$ncol_c,
    ncol_d = ncols$ncol_d,
    scales = fs$scales
  )
}
|
| 122 | ||
| 123 | ||
| 124 |
# Feature names of `obj` in the requested order:
#   "alphabetical" - sorted names
#   "importance"   - factor levels of `feature` in the "var_imp" attribute
#   "range"        - factor levels of `feature` in the "range" attribute
resolve_facet_sort <- function(obj, facet_sort) {
  switch(facet_sort,
    alphabetical = sort(names(obj)),
    importance = levels(attr(obj, "var_imp")$feature),
    range = levels(attr(obj, "range")$feature),
    stop("Unknown sorting")
  )
}
|
| 138 | ||
| 139 |
# Subset `obj` to the features selected by `facet_top_k` (first k entries of
# `feat_ordered`) or `facet_subset` (entries of `feat_ordered` contained in
# it); with neither given, all of `feat_ordered` is kept. The result follows
# the order of `feat_ordered` and keeps all attributes of `obj`, with `names`
# updated to the selection.
resolve_facet_subset_topk <- function(obj, facet_subset, facet_top_k, feat_ordered) {
  stopifnot(
    is.null(facet_subset) || is.null(facet_top_k),
    all(feat_ordered %in% names(obj))
  )
  if (!is.null(facet_top_k)) {
    # `seq_len()` instead of `1:min(...)`: with no features, `1:0` would
    # select an `NA` entry instead of nothing
    n_keep <- min(floor(facet_top_k), length(feat_ordered))
    vars <- feat_ordered[seq_len(n_keep)]
  } else if (!is.null(facet_subset)) {
    vars <- feat_ordered[feat_ordered %in% facet_subset]
  } else {
    vars <- feat_ordered
  }
  # subsetting drops attributes, so copy them over and refresh the names
  new_attrs <- attributes(obj)
  obj <- obj[vars]
  new_attrs$names <- names(obj)
  attributes(obj) <- new_attrs
  return(obj)
}
|
| 157 | ||
| 158 |
# Number of facet columns for continuous (`ncol_c`) and discrete (`ncol_d`)
# features: the requested `facet_ncol` (4 when `NULL`), but never more than
# the number of features of that kind. `idx_c` flags continuous features.
resolve_facet_ncol <- function(idx_c, facet_ncol) {
  max_cols <- if (is.null(facet_ncol)) 4 else facet_ncol
  n_continuous <- sum(idx_c)
  n_discrete <- sum(!idx_c)
  list(
    ncol_c = min(c(n_continuous, max_cols)),
    ncol_d = min(c(n_discrete, max_cols))
  )
}
|
| 170 | ||
| 171 | ||
| 172 | ||
| 173 |
# Compute common y-axis limits for the continuous and discrete plots.
#
# With free y scales ("free" or "free_y") no limits are imposed and
# c(NA, NA) is returned. Otherwise the pooled `y` values of `dat_c` and
# `dat_d` are taken and the limits are rounded outwards to one decimal place.
resolve_y_limits <- function(dat_c, dat_d, facet_scales) {
  if (facet_scales %in% c("free", "free_y")) {
    return(c(NA_real_, NA_real_))
  }
  y_all <- c(dat_c[["y"]], dat_d[["y"]])
  lower <- floor(min(y_all) * 10) / 10
  upper <- ceiling(max(y_all) * 10) / 10
  c(lower, upper)
}
|
| 183 | ||
| 184 | ||
| 185 | ||
| 186 |
# missings specification ------- |
|
| 187 | ||
| 188 | ||
| 189 |
#' Instructions for missings visualisations |
|
| 190 |
#' |
|
| 191 |
#' @param vline (`logical`) Should the vertical line be shown? Defaults to `FALSE`. |
|
| 192 |
#' @param hline (`logical`) Should the horizontal line be shown? Defaults to `FALSE`. |
|
| 193 |
#' @param values (`NULL`) or single value or a named vector. |
|
| 194 |
#' Specifies the value(-s) that stand for the missing values. |
|
| 195 |
#' If `NULL`, then no missing value handling is carried out. |
|
| 196 |
#' If single value, then it is assumed that this value is used for flagging missing values across |
|
| 197 |
#' all continuous variables. |
|
| 198 |
#' If named vector, then the names are used to refer to continuous variables and the values for |
|
| 199 |
#' flagging missing values in that variable. |
|
| 200 |
#' @param drop_from_plot (`logical`) Should the missing values be dropped from plot? |
|
| 201 |
#' Defaults to `FALSE`. |
|
| 202 |
#' |
|
| 203 |
#' @return List of class `missings_specification`. |
|
| 204 |
#' @export |
|
| 205 |
#' |
|
| 206 |
#' @examples |
|
| 207 |
#' \dontrun{
|
|
| 208 |
#' g_ice( |
|
| 209 |
#' sculpture, |
|
| 210 |
#' missings_spec = missings_specification( |
|
| 211 |
#' vline = TRUE, # show vertical line |
|
| 212 |
#' values = -1 # NAs in all continuous variables displayed as -1 |
|
| 213 |
#' ) |
|
| 214 |
#' ) |
|
| 215 |
#' } |
|
| 216 |
#' |
|
| 217 |
missings_specification <- function(vline = FALSE, hline = FALSE, values = NULL,
                                   drop_from_plot = FALSE) {
  # validate the three switches
  checkmate::assert_flag(vline)
  checkmate::assert_flag(hline)
  checkmate::assert_flag(drop_from_plot)

  # `values` is NULL, a single atomic value, or a named atomic vector
  checkmate::assert(
    checkmate::check_null(values),
    checkmate::check_atomic(values, any.missing = FALSE, len = 1),
    checkmate::check_atomic(values, any.missing = FALSE, min.len = 2, names = "named")
  )

  # every option below needs the missing-value markers to be known
  no_values <- is.null(values)
  if ((vline || hline) && no_values) {
    stop("Specified to show lines, but no missing values provided.")
  }
  if (drop_from_plot && no_values) {
    stop("Specified to drop missings from plot area, but no missing values provided.")
  }

  # a vertical line at a dropped value would point at nothing
  if (drop_from_plot && vline) {
    stop("Please use either `drop_from_plot` or `vline`, but not both.")
  }

  structure(
    list(vline = vline, hline = hline, values = values, drop_from_plot = drop_from_plot),
    class = "missings_specification"
  )
}
|
| 240 | ||
| 241 |
# Apply a missings specification to the plot data of continuous features.
#
# `dat_c` is a data.table with (at least) columns `feature`, `x` and `y`;
# `missings` is NULL or a data.table with columns `feature` and `x` holding
# the value that flags a missing observation per feature; `ms` is read only
# for its `drop_from_plot` element.
# Returns a list with `dat_c` (with the flagged rows removed when
# `ms$drop_from_plot` is set) and `missings` (the flagged rows found in the
# data together with their `y`, used for drawing reference lines).
resolve_missings_specification <- function(dat_c, ms, missings) {
  # nothing to resolve when no missing-value markers were supplied
  if (is.null(missings)) {
    return(list(dat_c = dat_c, missings = missings))
  }

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  line_id <- feature <- NULL # due to NSE notes in R CMD check

  # when individual lines are present, only the "pdp" line is matched
  pdp_rows <- if ("line_id" %in% colnames(dat_c)) dat_c[line_id == "pdp"] else dat_c

  # inner join: keep the rows whose (feature, x) is a flagged missing value
  matched <- missings[pdp_rows, nomatch = NULL, on = c("feature", "x")]
  keep_cols <- c("feature", "x", "y", if ("Model" %in% colnames(matched)) "Model")
  matched <- matched[, keep_cols, with = FALSE]
  matched[, feature := factor(feature, levels = levels(dat_c$feature))]

  # remove missing observations if requested (anti join)
  if (ms$drop_from_plot) {
    dat_c <- dat_c[!missings, on = c("feature", "x")]
  }

  list(dat_c = dat_c, missings = matched)
}
|
| 266 | ||
| 267 | ||
| 268 | ||
| 269 |
# plots ------- |
|
| 270 | ||
| 271 | ||
| 272 |
#' Plot variable importances and cumulative approximation of R^2 |
|
| 273 |
#' |
|
| 274 |
#' @param object (`sculpture`) |
|
| 275 |
#' @param feat_labels (`NULL`) or named character vector providing the variable labels. |
|
| 276 |
#' @param textsize Size of text. |
|
| 277 |
#' @param top_k (`NULL`) or number to show only the most `k` important variables. |
|
| 278 |
#' @param pdp_plot_sample (`logical`) Sample PDP for faster plotting? Defaults to `TRUE`. |
|
| 279 |
#' @param show_pdp_plot (`logical`) Show plot with PDP ranges? Defaults to `TRUE`. |
|
| 280 |
#' @param var_imp_type (`character`) One of `c("normalized", "absolute", "ice", "ice_orig_mod")`.
|
|
| 281 |
#' Defaults to "normalized". "ice" is only valid for a rough sculpture. |
|
| 282 |
#' @param logodds_to_prob (`logical`) Only valid for binary response and sculptures built on |
|
| 283 |
#' the log-odds scale. Defaults to `FALSE` (i.e. no effect). |
|
| 284 |
#' If `TRUE`, then the y-values are transformed through inverse logit function 1 / (1 + exp(-x)). |
|
| 285 |
#' @param plot_ratios (`numeric`) Used in the layout matrix of `gridExtra::arrangeGrob()`. |
|
| 286 |
#' If `show_pdp_plot`, then the default is `c(3,2,2)`, making the first plot 3 units wide and |
|
| 287 |
#' the other two plots 2 units wide. |
|
| 288 |
#' If `!show_pdp_plot`, then the default is `c(3,2)`, making the first plot 3 units wide and |
|
| 289 |
#' the second plot 2 units wide. |
|
| 290 |
#' Note that the length needs to be 3 if `show_pdp_plot` or 2 if `!show_pdp_plot`. |
|
| 291 |
#' |
|
| 292 |
#' @return `grob`. Use `grid::grid.draw` to plot the output |
|
| 293 |
#' (`grid::grid.newpage` resets the plotting area). |
|
| 294 |
#' |
|
| 295 |
#' @export |
|
| 296 |
#' |
|
| 297 |
#' @examples |
|
| 298 |
#' df <- mtcars |
|
| 299 |
#' df$vs <- as.factor(df$vs) |
|
| 300 |
#' model <- rpart::rpart( |
|
| 301 |
#' hp ~ mpg + carb + vs, |
|
| 302 |
#' data = df, |
|
| 303 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 304 |
#' ) |
|
| 305 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 306 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 307 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 308 |
#' |
|
| 309 |
#' rs <- sculpt_rough( |
|
| 310 |
#' dat = pm, |
|
| 311 |
#' model_predict_fun = model_predict, |
|
| 312 |
#' n_ice = 10, |
|
| 313 |
#' seed = 1, |
|
| 314 |
#' verbose = 0 |
|
| 315 |
#' ) |
|
| 316 |
#' |
|
| 317 |
#' # optionally define labels |
|
| 318 |
#' labels <- structure( |
|
| 319 |
#' toupper(covariates), # labels |
|
| 320 |
#' names = covariates # current (old) names |
|
| 321 |
#' ) |
|
| 322 |
#' vi <- g_var_imp(rs, feat_labels = labels) |
|
| 323 |
#' grid::grid.draw(vi) |
|
| 324 |
#' |
|
| 325 |
g_var_imp <- function(object, feat_labels = NULL, textsize = 16, top_k = NULL,
                      pdp_plot_sample = TRUE,
                      show_pdp_plot = TRUE,
                      var_imp_type = "normalized",
                      logodds_to_prob = FALSE,
                      plot_ratios = `if`(show_pdp_plot, c(3, 2, 2), c(3, 2))) {
  # ---- argument validation ----
  checkmate::assert_class(object, "sculpture")
  checkmate::assert_integerish(textsize, len = 1, any.missing = FALSE)
  checkmate::assert_flag(pdp_plot_sample)
  checkmate::assert_flag(show_pdp_plot)
  checkmate::assert_choice(var_imp_type, choices = c("normalized", "absolute", "ice", "ice_orig_mod"))
  checkmate::assert_flag(logodds_to_prob)

  # one ratio per panel: 3 panels with the PDP plot, 2 without
  if (show_pdp_plot) {
    checkmate::assert_integerish(plot_ratios, lower = 1, len = 3, any.missing = FALSE)
  } else {
    checkmate::assert_integerish(plot_ratios, lower = 1, len = 2, any.missing = FALSE)
  }

  # default labels are the feature names themselves
  if (is.null(feat_labels)) {
    feat_labels <- structure(names(object), names = names(object))
  }
  checkmate::assert_character(feat_labels, names = "named", len = length(object))

  checkmate::assert(
    checkmate::check_null(top_k),
    checkmate::check_integerish(top_k, any.missing = FALSE, len = 1, lower = 1)
  )

  # Remember whether the input is a rough sculpture *before* any subsetting:
  # the `top_k` subsetting below resets the class to "sculpture", which would
  # otherwise make the "ice_orig_mod" check fail for a valid rough sculpture.
  is_rough <- inherits(object, "rough")

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  feature <- . <- pdp_c <- ice_centered <- line_id <- var_y <-
    ice <- y <-
    NULL # due to NSE notes in R CMD check

  if (logodds_to_prob) {
    # evaluate the sculpture
    es <- eval_sculpture(
      sculpture = object,
      data = as.data.frame(as.data.table(lapply(object, "[[", "x")))
    )
    dt <- es$pdp

    # convert log-odds scale to probability scale and center back to 0
    dt[, pdp_c := inv.logit(pdp_c) - 0.5]

    # get importance
    dat_var <- calc_dir_var_imp_pdp(dt)
    feat_order <- levels(dat_var$feature)

    # get cumul. R2
    dat_R2_cumul <- calc_cumul_R2_pdp(
      dt = dt,
      feat_order = feat_order,
      model_predictions = inv.logit(es$prediction$pred),
      model_offset = inv.logit(es$offset)
    )
  } else {
    # get importance (precomputed and stored on the sculpture)
    dat_var <- attr(object, "var_imp")
    feat_order <- levels(dat_var$feature)

    # get cumul. R2 (precomputed and stored on the sculpture)
    dat_R2_cumul <- attr(object, "cumul_R2")

    # get PDPs and predictions - only needed for the PDP panel
    if (show_pdp_plot) {
      eg <- eval_sculpture(
        sculpture = object,
        data = as.data.frame(as.data.table(lapply(object, "[[", "x")))
      )
      dt <- eg$pdp
    }
  }

  # subset top_k if requested
  if (!is.null(top_k)) {
    top_k <- min(top_k, nrow(dat_var))
    # seq_len() instead of 1:top_k so that an empty importance table selects
    # nothing rather than indexing with c(1, 0)
    feat_order <- feat_order[seq_len(top_k)]
    dat_var <- dat_var[feature %in% feat_order]
    dat_R2_cumul <- dat_R2_cumul[feature %in% feat_order]
    if (show_pdp_plot) {
      object <- object[feat_order]
      class(object) <- "sculpture"
    }
  }

  # g1 - PDP values
  if (show_pdp_plot) {
    # check centering
    check_dt <- dt[, .(mean_pdp_c = mean(pdp_c)), .(feature)]
    if (abs(mean(check_dt$mean_pdp_c)) > 1e-1) {
      stop(paste(
        "PDPs not centered, mean relative difference of",
        abs(mean(check_dt$mean_pdp_c))
      ))
    }
    # draw PDP plot
    dt$feature <- factor(dt$feature, levels = feat_order)
    g1 <- g_pdp(dt = dt, pdp_plot_sample = pdp_plot_sample, feat_labels = feat_labels)
  } else {
    g1 <- NULL
  }

  # g2 - variable importance
  if (var_imp_type == "normalized") {
    g2 <- g_imp_norm(dat_var = dat_var, show_pdp_plot = show_pdp_plot, textsize = textsize)
  } else if (var_imp_type == "absolute") {
    g2 <- g_imp_abs(dat_var = dat_var, show_pdp_plot = show_pdp_plot, textsize = textsize)
  } else if (var_imp_type == "ice_orig_mod") {
    # this branch reads the `ice` element stored in each feature of the sculpture
    if (!is_rough) {
      stop('`var_imp_type == "ice_orig_mod"` is only valid for a rough sculpture.')
    }
    # get ice curves
    dat_var_ice <- rbindlist(
      lapply(
        object,
        function(v) {
          generate_ice_data(
            predictions = v[["ice"]],
            x = v$x,
            logodds_to_prob = logodds_to_prob
          )[, .(y, line_id)]
        }
      ),
      idcol = "feature"
    )
    # convert to factor
    dat_var_ice$feature <- factor(dat_var_ice$feature, levels = feat_order)
    # calculate variance
    dat_var_ice <- dat_var_ice[, .(var_y = var(y)), by = .(feature, line_id)]
    # calculate mean of variances
    vars_mean <- dat_var_ice[, .(mean_var_y = mean(var_y)), by = .(feature)]
    # plot ice variances
    g2 <- g_imp_ice(vars = dat_var_ice, vars_mean = vars_mean)
  } else if (var_imp_type == "ice") {
    # ICE curves are recomputed using the sculpture itself as the predictor
    model_predict_fun <- function(x) {
      if (logodds_to_prob) {
        p <- predict(object, newdata = x)
        inv.logit(p)
      } else {
        predict(object, newdata = x)
      }
    }

    dat_var_ice <- rbindlist(
      lapply(
        object,
        function(v) {
          calculate_ice_data(
            sub = v$subsets,
            predict_fun = model_predict_fun,
            x = v$x,
            x_name = v$x_name,
            col_order = names(object)
          )[, .(ice, line_id)]
        }
      ),
      idcol = "feature"
    )

    # convert to factor
    dat_var_ice$feature <- factor(dat_var_ice$feature, levels = feat_order)
    # calculate variance
    dat_var_ice <- dat_var_ice[, .(var_y = var(ice)), by = .(feature, line_id)]
    # calculate mean of variances
    vars_mean <- dat_var_ice[, .(mean_var_y = mean(var_y)), by = .(feature)]
    # plot ice variances
    g2 <- g_imp_ice(vars = dat_var_ice, vars_mean = vars_mean)
  }

  # the PDP panel (if shown) already carries the feature labels on its y axis
  if (show_pdp_plot) {
    g2 <- g2 + theme(axis.ticks.y = element_blank(), axis.text.y = element_blank())
  } else {
    g2 <- g2 + scale_y_discrete(labels = function(x) feat_labels[x])
  }

  # g3 - cumulative R2
  g3 <- g_cumulR2(dat_R2_cumul = dat_R2_cumul, textsize = textsize)

  # combined graph; `plot_ratios` gives the relative width of each panel
  if (show_pdp_plot) {
    combined <- gridExtra::arrangeGrob(
      g1 + theme(
        plot.margin = unit(c(0.8, 0.5, 0.3, 0.3), "cm"),
        text = element_text(size = textsize)
      ),
      g2 + theme(
        plot.margin = unit(c(0.8, 0.5, 0.3, 0.3), "cm"),
        text = element_text(size = textsize)
      ),
      g3 + theme(
        plot.margin = unit(c(0.8, 0.5, 0.05, 0.3), "cm"),
        text = element_text(size = textsize)
      ),
      layout_matrix = matrix(rep(1:3, plot_ratios), nrow = 1)
    )
  } else {
    combined <- gridExtra::arrangeGrob(
      g2 + theme(
        plot.margin = unit(c(0.8, 0.5, 0.3, 0.3), "cm"),
        text = element_text(size = textsize)
      ),
      g3 + theme(
        plot.margin = unit(c(0.8, 0.5, 0.05, 0.3), "cm"),
        text = element_text(size = textsize)
      ),
      layout_matrix = matrix(rep(1:2, plot_ratios), nrow = 1)
    )
  }

  return(combined)
}
|
| 537 | ||
| 538 | ||
| 539 |
#' Plot additivity scatterplot(-s) with R^2 value(-s) |
|
| 540 |
#' |
|
| 541 |
#' @param sp Sculpted predictions. Either as a vector or as a list of those. |
|
| 542 |
#' @param lp Learner predictions. Either as a vector or as a list of those. Same size as `sp`. |
|
| 543 |
#' @param descriptions (Optional) Descriptions of the models to be shown on the plot. |
|
| 544 |
#' Same size as `sp` if `sp` is provided as a list. |
|
| 545 |
#' @param cex `cex` graphical parameter. |
|
| 546 |
#' @param plot_only (`logical`) Return plot only or plot with the R^2 value? |
|
| 547 |
#' Defaults to the first (i.e. `TRUE`). |
|
| 548 |
#' |
|
| 549 |
#' @return If `plot_only`, then a plot. If `!plot_only`, then a plot and a data.frame. |
|
| 550 |
#' @export |
|
| 551 |
#' |
|
| 552 |
#' @examples |
|
| 553 |
#' df <- mtcars |
|
| 554 |
#' df$vs <- as.factor(df$vs) |
|
| 555 |
#' model <- rpart::rpart( |
|
| 556 |
#' hp ~ mpg + carb + vs, |
|
| 557 |
#' data = df, |
|
| 558 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 559 |
#' ) |
|
| 560 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 561 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 562 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 563 |
#' |
|
| 564 |
#' rs <- sculpt_rough( |
|
| 565 |
#' dat = pm, |
|
| 566 |
#' model_predict_fun = model_predict, |
|
| 567 |
#' n_ice = 10, |
|
| 568 |
#' seed = 1, |
|
| 569 |
#' verbose = 0 |
|
| 570 |
#' ) |
|
| 571 |
#' |
|
| 572 |
#' g_additivity( |
|
| 573 |
#' sp = predict(rs, pm), |
|
| 574 |
#' lp = model_predict(pm), |
|
| 575 |
#' descriptions = "Product Marginal" |
|
| 576 |
#' ) |
|
| 577 |
#' |
|
| 578 |
g_additivity <- function(sp, lp, descriptions = NULL, cex = 4, plot_only = TRUE) {
  # `sp`: one atomic vector or a list of them; normalise to a list
  checkmate::assert(
    checkmate::check_atomic(sp),
    checkmate::check_list(sp, types = "atomic")
  )
  if (!is.list(sp)) {
    sp <- list(sp)
  }

  # `lp`: one atomic vector (recycled for every element of `sp`) or a list
  # with one vector per element of `sp`, each of matching length
  checkmate::assert(
    checkmate::check_atomic(lp, any.missing = FALSE, len = length(sp[[1]])),
    checkmate::check_list(lp, types = "atomic", len = length(sp))
  )
  if (!is.list(lp)) {
    lp <- rep(list(lp), length(sp))
  } else {
    for (i in seq_along(lp)) {
      checkmate::assert_atomic(lp[[i]], any.missing = FALSE, len = length(sp[[i]]))
    }
  }

  # descriptions default to the names of `sp`, or to "Sculpture <i>"
  checkmate::assert(
    checkmate::check_null(descriptions),
    checkmate::check_character(descriptions, any.missing = FALSE, len = length(sp))
  )
  if (is.null(descriptions)) {
    descriptions <- if (is.null(names(sp))) {
      paste("Sculpture", seq_along(sp))
    } else {
      names(sp)
    }
  }

  checkmate::assert_numeric(cex, lower = 0, any.missing = FALSE, len = 1)
  checkmate::assert_logical(plot_only, any.missing = FALSE, len = 1)

  # get plot data: one long data.frame with a `Model` facet column
  model_idx <- seq_along(sp)
  pd <- do.call(
    "rbind",
    lapply(model_idx, function(i) {
      data.frame(
        sculpted = sp[[i]],
        learner = lp[[i]],
        Model = descriptions[i]
      )
    })
  )
  pd$Model <- factor(pd$Model, levels = descriptions)

  # calculate R2 (vs strong learner)
  R2_mod_vs_approx <- vapply(
    model_idx,
    function(i) metrics_R2(score_fun = "score_quadratic", y = lp[[i]], y_hat = sp[[i]]),
    FUN.VALUE = numeric(1)
  )

  # create R2 annotations, placed near the top-left corner of each facet
  label_x <- min(pd$sculpted) + (max(pd$sculpted) - min(pd$sculpted)) / 10
  label_y <- 0.9 * max(pd$learner)
  annotations <- data.frame(
    R2 = paste0("R^2==", round(R2_mod_vs_approx, 4)),
    Model = factor(descriptions, levels = descriptions),
    sculpted = label_x,
    learner = label_y
  )

  # scatter of sculpted vs learner predictions with the identity line
  g <- ggplot(pd) +
    geom_point(aes(x = .data$sculpted, y = .data$learner), alpha = 0.4, shape = 16) +
    geom_abline(slope = 1, intercept = 0) +
    facet_wrap("Model") +
    geom_label(
      data = annotations,
      mapping = aes(x = .data$sculpted, y = .data$learner, label = .data$R2),
      hjust = 0, parse = TRUE, size = cex
    ) +
    theme_bw() +
    labs(
      x = "Sculpted Model Predictions",
      y = "Learner Predictions"
    )

  if (plot_only) {
    return(g)
  }
  list(plot = g, R2 = data.frame(description = descriptions, R2 = R2_mod_vs_approx))
}
|
| 662 | ||
| 663 | ||
| 664 | ||
| 665 |
#' Plot centered ICE profiles with centered PDP curves |
|
| 666 |
#' |
|
| 667 |
#' @param object Object of classes `rough` and `sculpture`. |
|
| 668 |
#' @param centered `logical`, centered ice plots? Defaults to `TRUE`. |
|
| 669 |
#' @param show_PDP `logical`, show PDP line? Defaults to `TRUE`. |
|
| 670 |
#' @param coloured `logical`, coloured curves? Defaults to `FALSE`. |
|
| 671 |
#' @param rug_sides "" for none, "b", for bottom, "trbl" for all 4 sides (see `geom_rug`) |
|
| 672 |
#' @param missings_spec Object of class `missings_specification`. |
|
| 673 |
#' @param facet_spec Object of class `facet_specification`. |
|
| 674 |
#' @param logodds_to_prob (`logical`) Only valid for binary response and sculptures built on |
|
| 675 |
#' the log-odds scale. Defaults to `FALSE` (i.e. no effect). |
|
| 676 |
#' If `TRUE`, then the y-values are transformed through inverse logit function 1 / (1 + exp(-x)). |
|
| 677 |
#' |
|
| 678 |
#' @return List of `ggplot`s (one for continuous features, one for discrete). |
|
| 679 |
#' @export |
|
| 680 |
#' |
|
| 681 |
#' @examples |
|
| 682 |
#' df <- mtcars |
|
| 683 |
#' df$vs <- as.factor(df$vs) |
|
| 684 |
#' model <- rpart::rpart( |
|
| 685 |
#' hp ~ mpg + carb + vs, |
|
| 686 |
#' data = df, |
|
| 687 |
#' control = rpart::rpart.control(minsplit = 10) |
|
| 688 |
#' ) |
|
| 689 |
#' model_predict <- function(x) predict(model, newdata = x) |
|
| 690 |
#' covariates <- c("mpg", "carb", "vs")
|
|
| 691 |
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5) |
|
| 692 |
#' |
|
| 693 |
#' rs <- sculpt_rough( |
|
| 694 |
#' dat = pm, |
|
| 695 |
#' model_predict_fun = model_predict, |
|
| 696 |
#' n_ice = 10, |
|
| 697 |
#' seed = 1, |
|
| 698 |
#' verbose = 0 |
|
| 699 |
#' ) |
|
| 700 |
#' |
|
| 701 |
#' g_ice(rs)$continuous |
|
| 702 |
#' |
|
| 703 |
g_ice <- function(object, centered = TRUE, show_PDP = TRUE, coloured = FALSE, |
|
| 704 |
rug_sides = "b", |
|
| 705 |
missings_spec = missings_specification(), |
|
| 706 |
facet_spec = facet_specification(), |
|
| 707 |
logodds_to_prob = FALSE) {
|
|
| 708 | ! |
checkmate::assert_class(object, "rough") |
| 709 | ! |
checkmate::assert_flag(centered) |
| 710 | ! |
checkmate::assert_flag(show_PDP) |
| 711 | ! |
checkmate::assert_flag(coloured) |
| 712 | ! |
checkmate::assert_character(rug_sides, any.missing = FALSE, len = 1) |
| 713 | ! |
checkmate::assert_class(facet_spec, "facet_specification") |
| 714 | ! |
checkmate::assert_class(missings_spec, "missings_specification") |
| 715 | ! |
checkmate::assert_flag(logodds_to_prob) |
| 716 | ||
| 717 |
# transform missings into a list of values per each continuous variable |
|
| 718 | ! |
check_continuous <- vapply(object, "[[", logical(1), "is_discrete") |
| 719 | ! |
check_continuous <- names(Filter(isFALSE, check_continuous)) |
| 720 | ! |
if (length(missings_spec$values) == 1) {
|
| 721 | ! |
missings <- data.table(feature = check_continuous, x = missings_spec$values) |
| 722 | ! |
} else if (length(missings_spec$values) > 1) {
|
| 723 | ! |
missings <- data.table(feature = names(missings_spec$values), x = missings_spec$values) |
| 724 | ! |
checkmate::assert_names( |
| 725 | ! |
missings$feature, |
| 726 | ! |
subset.of = check_continuous, |
| 727 | ! |
.var.name = "missings_spec$values" |
| 728 |
) |
|
| 729 |
} else {
|
|
| 730 | ! |
missings <- NULL |
| 731 |
} |
|
| 732 | ||
| 733 | ! |
if (coloured & show_PDP) {
|
| 734 | ! |
stop("Coloured lines are only available without PDP, so please set `show_PDP = FALSE`.")
|
| 735 |
} |
|
| 736 | ||
| 737 |
# https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals |
|
| 738 | ! |
. <- x <- x_ribbon <- line_id <- y_se <- feature <- NULL # due to NSE notes in R CMD check |
| 739 | ||
| 740 |
# resolve facet specification |
|
| 741 | ! |
rfs <- resolve_facet_specification(obj = object, fs = facet_spec) |
| 742 | ! |
object <- rfs$object |
| 743 | ||
| 744 |
# get continuous vars |
|
| 745 | ! |
idx_continuous_vars <- !vapply(object, "[[", logical(1), "is_discrete") |
| 746 | ! |
has_continuous <- any(idx_continuous_vars) |
| 747 | ! |
has_discrete <- any(!idx_continuous_vars) |
| 748 | ||
| 749 | ! |
pred_var <- if (centered) "ice_centered" else "ice" |
| 750 | ||
| 751 |
# continuous |
|
| 752 | ! |
if (has_continuous) {
|
| 753 |
# ICE |
|
| 754 | ! |
ice_continuous <- rbindlist( |
| 755 | ! |
lapply( |
| 756 | ! |
object[idx_continuous_vars], |
| 757 | ! |
function(v) {
|
| 758 | ! |
generate_ice_data( |
| 759 | ! |
predictions = v[[pred_var]], |
| 760 | ! |
x = v$x, |
| 761 | ! |
logodds_to_prob = logodds_to_prob |
| 762 |
) |
|
| 763 |
} |
|
| 764 |
), |
|
| 765 | ! |
idcol = "feature" |
| 766 |
) |
|
| 767 | ! |
ice_continuous <- unique(ice_continuous) |
| 768 | ! |
ice_continuous[, `:=`(type = "ICE Profiles", line_id = as.character(line_id))] |
| 769 | ||
| 770 |
# PDP |
|
| 771 | ! |
if (show_PDP) {
|
| 772 | ! |
pdp_continuous <- rbindlist( |
| 773 | ! |
lapply( |
| 774 | ! |
object[idx_continuous_vars], |
| 775 | ! |
function(v) {
|
| 776 | ! |
generate_pdp_data( |
| 777 | ! |
predictions = v[[pred_var]], |
| 778 | ! |
x = v$x, |
| 779 | ! |
logodds_to_prob = logodds_to_prob |
| 780 |
) |
|
| 781 |
} |
|
| 782 |
), |
|
| 783 | ! |
idcol = "feature" |
| 784 |
) |
|
| 785 | ! |
pdp_continuous[, y_se := ifelse(is.na(y_se), 0, y_se)] |
| 786 | ! |
pdp_continuous[, `:=`(line_id = "pdp", type = "Rough model (with SE)")] |
| 787 | ||
| 788 | ! |
dat_c <- rbind(ice_continuous, pdp_continuous, fill = TRUE) |
| 789 |
} else {
|
|
| 790 | ! |
dat_c <- ice_continuous |
| 791 |
} |
|
| 792 |
} else {
|
|
| 793 | ! |
dat_c <- data.table(y = numeric(0), feature = character(0)) |
| 794 |
} |
|
| 795 | ||
| 796 |
# discrete |
|
| 797 | ! |
if (has_discrete) {
|
| 798 |
# ICE |
|
| 799 | ! |
ice_discrete <- rbindlist( |
| 800 | ! |
lapply( |
| 801 | ! |
object[!idx_continuous_vars], |
| 802 | ! |
function(v) {
|
| 803 | ! |
generate_ice_data( |
| 804 | ! |
predictions = v[[pred_var]], |
| 805 | ! |
x = v$x, |
| 806 | ! |
logodds_to_prob = logodds_to_prob |
| 807 |
) |
|
| 808 |
} |
|
| 809 |
), |
|
| 810 | ! |
idcol = "feature" |
| 811 |
) |
|
| 812 | ! |
ice_discrete <- unique(ice_discrete) |
| 813 | ! |
ice_discrete[, `:=`(type = "ICE Profiles", line_id = as.character(line_id))] |
| 814 | ||
| 815 |
# PDP |
|
| 816 | ! |
if (show_PDP) {
|
| 817 | ! |
pdp_discrete <- rbindlist( |
| 818 | ! |
lapply( |
| 819 | ! |
object[!idx_continuous_vars], |
| 820 | ! |
function(v) {
|
| 821 | ! |
generate_pdp_data( |
| 822 | ! |
predictions = v[[pred_var]], |
| 823 | ! |
x = v$x, |
| 824 | ! |
logodds_to_prob = logodds_to_prob |
| 825 |
) |
|
| 826 |
} |
|
| 827 |
), |
|
| 828 | ! |
idcol = "feature" |
| 829 |
) |
|
| 830 | ! |
pdp_discrete[, `:=`(line_id = "pdp", type = "Rough model (with SE)")] |
| 831 | ||
| 832 | ! |
dat_d <- rbind(ice_discrete, pdp_discrete, fill = TRUE) |
| 833 |
} else {
|
|
| 834 | ! |
dat_d <- ice_discrete |
| 835 |
} |
|
| 836 |
} else {
|
|
| 837 | ! |
dat_d <- data.table(y = numeric(0), feature = character(0)) |
| 838 |
} |
|
| 839 | ||
| 840 |
# resolve y limits |
|
| 841 | ! |
y_limits <- resolve_y_limits(dat_c = dat_c, dat_d = dat_d, facet_scales = facet_spec$scales) |
| 842 | ||
| 843 |
# resolve facet sort - need to convert to factor |
|
| 844 | ! |
dat_c[, feature := factor(feature, levels = names(object)[idx_continuous_vars])] |
| 845 | ! |
dat_d[, feature := factor(feature, levels = names(object)[!idx_continuous_vars])] |
| 846 | ||
| 847 |
# resolve missings specification |
|
| 848 | ! |
rms <- resolve_missings_specification(dat_c = dat_c, ms = missings_spec, missings = missings) |
| 849 | ! |
dat_c <- rms$dat_c |
| 850 | ! |
missings <- rms$missings |
| 851 | ||
| 852 |
# graph for continuous |
|
| 853 | ! |
if (nrow(dat_c) > 0) {
|
| 854 | ! |
gc <- ggplot() |
| 855 | ||
| 856 | ! |
if (show_PDP) {
|
| 857 | ! |
gc <- gc + |
| 858 | ! |
geom_ribbon( |
| 859 | ! |
mapping = aes( |
| 860 | ! |
x = .data$x, |
| 861 | ! |
ymin = .data$y - .data$y_se, |
| 862 | ! |
ymax = .data$y + .data$y_se |
| 863 |
), |
|
| 864 | ! |
data = dat_c[line_id == "pdp"], |
| 865 | ! |
na.rm = T, alpha = 0.4, colour = "lightblue", fill = "lightblue" |
| 866 |
) |
|
| 867 |
} |
|
| 868 | ||
| 869 | ! |
if (coloured) {
|
| 870 | ! |
gc <- gc + |
| 871 | ! |
geom_line( |
| 872 | ! |
mapping = aes(x = .data$x, y = .data$y, colour = .data$line_id), |
| 873 | ! |
linewidth = 1, |
| 874 | ! |
alpha = 0.3, |
| 875 | ! |
data = dat_c, |
| 876 | ! |
na.rm = F |
| 877 |
) |
|
| 878 |
} else {
|
|
| 879 | ! |
gc <- gc + |
| 880 | ! |
geom_line( |
| 881 | ! |
mapping = aes( |
| 882 | ! |
x = .data$x, y = .data$y, |
| 883 | ! |
group = .data$line_id, colour = .data$type |
| 884 |
), |
|
| 885 | ! |
data = dat_c, |
| 886 | ! |
na.rm = F |
| 887 |
) |
|
| 888 |
} |
|
| 889 | ||
| 890 | ! |
gc <- gc + |
| 891 | ! |
geom_rug( |
| 892 | ! |
mapping = aes(x = .data$x, y = .data$y), |
| 893 | ! |
data = dat_c[line_id == 1], |
| 894 | ! |
na.rm = F, |
| 895 | ! |
sides = rug_sides |
| 896 |
) + |
|
| 897 | ! |
facet_wrap( |
| 898 | ! |
"feature", |
| 899 | ! |
scales = rfs$scales, ncol = rfs$ncol_c, |
| 900 | ! |
labeller = as_labeller(rfs$labels) |
| 901 |
) |
|
| 902 | ||
| 903 | ! |
if (!coloured) {
|
| 904 | ! |
gc <- gc + |
| 905 | ! |
scale_color_manual( |
| 906 | ! |
values = c( |
| 907 | ! |
"ICE Profiles" = ifelse(show_PDP, "gray60", "black"), |
| 908 | ! |
"Rough model (with SE)" = "blue" |
| 909 | ! |
)[c(T, show_PDP)], |
| 910 | ! |
name = "" |
| 911 |
) |
|
| 912 |
} |
|
| 913 | ||
| 914 |
# add missings lines |
|
| 915 | ! |
if (!is.null(missings)) {
|
| 916 | ! |
if (missings_spec$vline) {
|
| 917 | ! |
gc <- gc + |
| 918 | ! |
geom_vline( |
| 919 | ! |
mapping = aes(xintercept = .data$x), |
| 920 | ! |
data = missings, |
| 921 | ! |
linetype = "dotted" |
| 922 |
) |
|
| 923 |
} |
|
| 924 | ! |
if (missings_spec$hline) {
|
| 925 | ! |
gc <- gc + |
| 926 | ! |
geom_hline( |
| 927 | ! |
mapping = aes(yintercept = .data$y, linetype = "Score for Missing Feature"), |
| 928 | ! |
data = missings |
| 929 |
) + |
|
| 930 | ! |
scale_linetype_manual(values = c("Score for Missing Feature" = "dotted"), name = NULL)
|
| 931 |
} |
|
| 932 |
} |
|
| 933 | ||
| 934 | ! |
gc <- gc + |
| 935 | ! |
labs(x = "Features", y = "Feature Score", caption = "", colour = NULL) + |
| 936 | ! |
ylim(y_limits) + |
| 937 | ! |
theme_bw() |
| 938 | ||
| 939 | ! |
if (show_PDP) {
|
| 940 | ! |
gc <- gc + guides(colour = guide_legend(override.aes = list(size = 1))) |
| 941 |
} else {
|
|
| 942 | ! |
gc <- gc + guides(colour = guide_none()) |
| 943 |
} |
|
| 944 | ||
| 945 | ! |
if (coloured) {
|
| 946 | ! |
dat_c2 <- dat_c |
| 947 | ! |
dat_c$x_perc <- ecdf(dat_c$x)(dat_c$x) |
| 948 | ! |
gc2 <- ggplot() + |
| 949 | ! |
geom_line( |
| 950 | ! |
mapping = aes(x = .data$x_perc, y = .data$y, colour = .data$line_id), |
| 951 | ! |
linewidth = 1, |
| 952 | ! |
alpha = 0.3, |
| 953 | ! |
data = dat_c, |
| 954 | ! |
na.rm = F |
| 955 |
) + |
|
| 956 | ! |
geom_rug( |
| 957 | ! |
mapping = aes(x = .data$x_perc, y = .data$y), |
| 958 | ! |
data = dat_c[line_id == 1], |
| 959 | ! |
na.rm = F, |
| 960 | ! |
sides = rug_sides |
| 961 |
) + |
|
| 962 | ! |
facet_wrap( |
| 963 | ! |
"feature", |
| 964 | ! |
scales = rfs$scales, ncol = rfs$ncol_c, |
| 965 | ! |
labeller = as_labeller(rfs$labels) |
| 966 |
) + |
|
| 967 | ! |
labs(x = "Features", y = "Feature Score", caption = "", colour = NULL) + |
| 968 | ! |
ylim(y_limits) + |
| 969 | ! |
theme_bw() + |
| 970 | ! |
guides(colour = guide_none()) |
| 971 |
} |
|
| 972 |
} else {
|
|
| 973 | ! |
gc <- NULL |
| 974 |
} |
|
| 975 | ||
| 976 |
# graph for discrete |
|
| 977 | ! |
if (nrow(dat_d) > 0) {
|
| 978 | ! |
gd <- ggplot() |
| 979 | ||
| 980 | ! |
if (show_PDP) {
|
| 981 | ! |
dat_d[line_id == "pdp", x_ribbon := as.numeric(droplevels(as.factor(x))), by = .(feature)] |
| 982 | ! |
gd <- gd + |
| 983 | ! |
geom_point( |
| 984 | ! |
mapping = aes(x = .data$x, y = .data$y, colour = .data$type), |
| 985 | ! |
data = dat_d[line_id == "pdp"], |
| 986 | ! |
na.rm = T |
| 987 |
) + |
|
| 988 | ! |
geom_ribbon( |
| 989 | ! |
mapping = aes( |
| 990 | ! |
x = .data$x_ribbon, |
| 991 | ! |
ymin = .data$y - .data$y_se, |
| 992 | ! |
ymax = .data$y + .data$y_se |
| 993 |
), |
|
| 994 | ! |
data = dat_d[line_id == "pdp"], |
| 995 | ! |
na.rm = T, alpha = 0.4, colour = "lightblue", fill = "lightblue" |
| 996 |
) |
|
| 997 |
} |
|
| 998 | ||
| 999 | ! |
if (coloured) {
|
| 1000 | ! |
gd <- gd + |
| 1001 | ! |
geom_line( |
| 1002 | ! |
mapping = aes(x = .data$x, y = .data$y, colour = .data$line_id, group = .data$line_id), |
| 1003 | ! |
linewidth = 1, |
| 1004 | ! |
alpha = 0.3, |
| 1005 | ! |
data = dat_d, |
| 1006 | ! |
na.rm = F |
| 1007 |
) |
|
| 1008 |
} else {
|
|
| 1009 | ! |
gd <- gd + |
| 1010 | ! |
geom_line( |
| 1011 | ! |
mapping = aes( |
| 1012 | ! |
x = .data$x, y = .data$y, |
| 1013 | ! |
colour = .data$type, group = .data$line_id |
| 1014 |
), |
|
| 1015 | ! |
data = dat_d, |
| 1016 | ! |
na.rm = F |
| 1017 |
) |
|
| 1018 |
} |
|
| 1019 | ||
| 1020 | ! |
gd <- gd + |
| 1021 | ! |
facet_wrap( |
| 1022 | ! |
"feature", scales = rfs$scales, ncol = rfs$ncol_d, |
| 1023 | ! |
labeller = as_labeller(rfs$labels) |
| 1024 |
) |
|
| 1025 | ||
| 1026 | ! |
if (!coloured) {
|
| 1027 | ! |
gd <- gd + |
| 1028 | ! |
scale_color_manual( |
| 1029 | ! |
values = c( |
| 1030 | ! |
"ICE Profiles" = ifelse(show_PDP, "gray60", "black"), |
| 1031 | ! |
"Rough model (with SE)" = "blue" |
| 1032 | ! |
)[c(T, show_PDP)], |
| 1033 | ! |
name = "" |
| 1034 |
) |
|
| 1035 |
} |
|
| 1036 | ||
| 1037 | ! |
gd <- gd + |
| 1038 | ! |
labs(x = "Features", y = "Feature Score", caption = "", colour = NULL, linetype = NULL) + |
| 1039 | ! |
ylim(y_limits) + |
| 1040 | ! |
theme_bw() |
| 1041 | ||
| 1042 | ! |
if (show_PDP) {
|
| 1043 | ! |
gd <- gd + guides(colour = guide_legend(override.aes = list(size = 1))) |
| 1044 |
} else {
|
|
| 1045 | ! |
gd <- gd + guides(colour = guide_none()) |
| 1046 |
} |
|
| 1047 |
} else {
|
|
| 1048 | ! |
gd <- NULL |
| 1049 |
} |
|
| 1050 | ||
| 1051 | ! |
if (coloured) {
|
| 1052 | ! |
return(list(continuous = gc, discrete = gd, perc = gc2)) |
| 1053 |
} else {
|
|
| 1054 | ! |
return(list(continuous = gc, discrete = gd)) |
| 1055 |
} |
|
| 1056 |
} |
|
| 1057 | ||
| 1058 | ||
| 1059 |
#' Plot component functions
#'
#' Draws the component function of every feature of a sculpture: each feature's
#' `predict()` is evaluated on its own stored `x` values and shown in a facet.
#'
#' @param object Object of class `sculpture`.
#' @inheritParams g_ice
#'
#' @return Named list of `ggplot`s: `continuous` (continuous features) and
#' `discrete` (discrete features). An element is `NULL` when the sculpture
#' has no feature of that kind.
#' @export
#'
#' @examples
#' df <- mtcars
#' df$vs <- as.factor(df$vs)
#' model <- rpart::rpart(
#'   hp ~ mpg + carb + vs,
#'   data = df,
#'   control = rpart::rpart.control(minsplit = 10)
#' )
#' model_predict <- function(x) predict(model, newdata = x)
#' covariates <- c("mpg", "carb", "vs")
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5)
#'
#' rs <- sculpt_rough(
#'   dat = pm,
#'   model_predict_fun = model_predict,
#'   n_ice = 10,
#'   seed = 1,
#'   verbose = 0
#' )
#'
#' ds <- sculpt_detailed_gam(rs)
#'
#' g_component(ds)$continuous
#'
g_component <- function(object, rug_sides = "b",
                        missings_spec = missings_specification(),
                        facet_spec = facet_specification(),
                        logodds_to_prob = FALSE) {
  checkmate::assert_class(object, "sculpture")
  checkmate::assert_character(rug_sides, any.missing = FALSE, len = 1)
  checkmate::assert_class(missings_spec, "missings_specification")
  checkmate::assert_class(facet_spec, "facet_specification")
  checkmate::assert_flag(logodds_to_prob)

  # transform missings into a list of values per each continuous variable
  check_continuous <- vapply(object, "[[", logical(1), "is_discrete")
  check_continuous <- names(Filter(isFALSE, check_continuous))
  if (length(missings_spec$values) == 1) {
    # a single value applies to every continuous feature
    missings <- data.table(feature = check_continuous, x = missings_spec$values)
  } else if (length(missings_spec$values) > 1) {
    # a named vector gives one value per (continuous) feature
    missings <- data.table(feature = names(missings_spec$values), x = missings_spec$values)
    checkmate::assert_names(
      missings$feature,
      subset.of = check_continuous,
      .var.name = "missings_spec$values"
    )
  } else {
    missings <- NULL
  }

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  feature <- NULL # due to NSE notes in R CMD check

  # resolve facet specification
  rfs <- resolve_facet_specification(obj = object, fs = facet_spec)
  object <- rfs$object

  # get continuous vars
  idx_continuous_vars <- !vapply(object, "[[", logical(1), "is_discrete")
  has_continuous <- any(idx_continuous_vars)
  has_discrete <- any(!idx_continuous_vars)

  # Evaluate each component on its own grid; one row per (feature, x) pair.
  component_data <- function(vars) {
    rbindlist(
      lapply(
        vars,
        function(v) {
          y <- v$predict(v$x)
          if (logodds_to_prob) {
            y <- inv.logit(y)
          }
          data.table(x = v$x, y = y)
        }
      ),
      idcol = "feature"
    )
  }

  if (has_continuous) {
    dat_c <- component_data(object[idx_continuous_vars])
  } else {
    dat_c <- data.table(feature = character(0))
  }

  if (has_discrete) {
    dat_d <- component_data(object[!idx_continuous_vars])
  } else {
    dat_d <- data.table(feature = character(0))
  }

  # resolve y limits
  y_limits <- resolve_y_limits(dat_c = dat_c, dat_d = dat_d, facet_scales = facet_spec$scales)

  # resolve facet sort - need to convert to factor
  dat_c[, feature := factor(feature, levels = names(object)[idx_continuous_vars])]
  dat_d[, feature := factor(feature, levels = names(object)[!idx_continuous_vars])]

  # resolve missings specification
  rms <- resolve_missings_specification(dat_c = dat_c, ms = missings_spec, missings = missings)
  dat_c <- rms$dat_c
  missings <- rms$missings

  if (missings_spec$hline) {
    # The component line gets a linetype legend entry so it can sit next to the
    # "Score for Missing Feature" entry. `legend_model_name` is a local string,
    # not a column of `dat_c`, hence it must not be looked up through `.data`.
    legend_model_name <- paste(stringr::str_to_title(class(object)[1]), "Model Component")
    line_mapping <- aes(
      x = .data$x, y = .data$y,
      group = .data$feature, linetype = legend_model_name
    )
  } else {
    line_mapping <- aes(x = .data$x, y = .data$y, group = .data$feature)
  }

  if (nrow(dat_c) > 0) {
    gc <- ggplot(dat_c) +
      geom_line(mapping = line_mapping) +
      geom_rug(
        mapping = aes(x = .data$x, y = .data$y),
        na.rm = FALSE,
        sides = rug_sides
      ) +
      facet_wrap(
        "feature",
        scales = rfs$scales, ncol = rfs$ncol_c,
        labeller = as_labeller(rfs$labels)
      )

    # add missings lines
    if (!is.null(missings)) {
      if (missings_spec$vline) {
        gc <- gc +
          geom_vline(
            mapping = aes(xintercept = .data$x),
            data = missings,
            linetype = "dotted"
          )
      }
      if (missings_spec$hline) {
        gc <- gc +
          geom_hline(
            mapping = aes(yintercept = .data$y, linetype = "Score for Missing Feature"),
            data = missings
          ) +
          scale_linetype_manual(
            values = structure(
              c("dotted", "solid"),
              names = c("Score for Missing Feature", legend_model_name)
            ),
            name = NULL
          )
      }
    }

    gc <- gc +
      labs(x = "Features", y = "Feature Score") +
      ylim(y_limits) +
      theme_bw()
  } else {
    gc <- NULL
  }

  if (nrow(dat_d) > 0) {
    gd <- ggplot(dat_d) +
      geom_line(aes(x = .data$x, y = .data$y, group = .data$feature)) +
      facet_wrap(
        "feature",
        scales = rfs$scales, ncol = rfs$ncol_d,
        labeller = as_labeller(rfs$labels)
      ) +
      labs(x = "Features", y = "Feature Score") +
      ylim(y_limits) +
      theme_bw()
  } else {
    gd <- NULL
  }

  return(list(continuous = gc, discrete = gd))
}
|
| 1248 | ||
| 1249 |
#' Plot comparison of component functions
#'
#' Overlays the component functions of several sculptures, one colour per
#' sculpture and one facet per feature.
#'
#' @param sculptures Non-empty list of objects of class `sculpture`.
#' @param descriptions Character vector with model names. Same length as `sculptures`.
#' @inheritParams g_ice
#' @param hue_coloring Logical, use hue-based coloring?
#' Defaults to FALSE, meaning that predefined colors will be used instead.
#'
#' @details The first element of `sculptures` works as a reference sculpture.
#' All other sculptures must have a subset of variables with respect to the first one
#' (i.e. the same variables or less, but not new ones).
#' This makes it possible to visualize polished together with non-polished sculptures,
#' if the non-polished one is specified as the first one.
#'
#' @return Named list of `ggplot`s: `continuous` (continuous features) and
#' `discrete` (discrete features). An element is `NULL` when there is no
#' feature of that kind.
#' @export
#'
#' @examples
#' df <- mtcars
#' df$vs <- as.factor(df$vs)
#' model <- rpart::rpart(
#'   hp ~ mpg + carb + vs,
#'   data = df,
#'   control = rpart::rpart.control(minsplit = 10)
#' )
#' model_predict <- function(x) predict(model, newdata = x)
#' covariates <- c("mpg", "carb", "vs")
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5)
#'
#' rs <- sculpt_rough(
#'   dat = pm,
#'   model_predict_fun = model_predict,
#'   n_ice = 10,
#'   seed = 1,
#'   verbose = 0
#' )
#'
#' ds <- sculpt_detailed_gam(rs)
#'
#' # this keeps only "mpg"
#' ps <- sculpt_polished(ds, k = 1)
#'
#' # also define simple labels
#' labels <- structure(
#'   toupper(covariates), # labels
#'   names = covariates # current (old) names
#' )
#'
#' # Component functions of "Detailed" and "Polished" are the same for "mpg" variable,
#' # therefore red curve overlays the blue one for "mpg"
#' comp <- g_comparison(
#'   sculptures = list(rs, ds, ps),
#'   descriptions = c("Rough", "Detailed", "Polished"),
#'   facet_spec = facet_specification(ncol = 2, labels = labels)
#' )
#' comp$continuous
#' comp$discrete
#'
g_comparison <- function(sculptures, descriptions, rug_sides = "b",
                         missings_spec = missings_specification(),
                         facet_spec = facet_specification(),
                         hue_coloring = FALSE,
                         logodds_to_prob = FALSE) {
  # `sculptures[[1]]` is used as the reference below, so at least one is needed
  checkmate::assert_list(sculptures, types = "sculpture", min.len = 1)
  checkmate::assert_character(descriptions, len = length(sculptures))
  checkmate::assert_character(rug_sides, any.missing = FALSE, len = 1)
  checkmate::assert_class(missings_spec, "missings_specification")
  checkmate::assert_class(facet_spec, "facet_specification")
  checkmate::assert_flag(hue_coloring)
  checkmate::assert_flag(logodds_to_prob)

  names_sc_1 <- names(sculptures[[1]])
  check_names <- vapply(sculptures, function(sc) all(names(sc) %in% names_sc_1), logical(1))
  if (!all(check_names)) {
    stop("All sculptures must be subsets of the first sculpture (in terms of variables).")
  }

  # transform missings into a list of values per each continuous variable
  check_continuous <- vapply(sculptures[[1]], "[[", logical(1), "is_discrete")
  check_continuous <- names(Filter(isFALSE, check_continuous))
  if (length(missings_spec$values) == 1) {
    # a single value applies to every continuous feature
    missings <- data.table(feature = check_continuous, x = missings_spec$values)
  } else if (length(missings_spec$values) > 1) {
    # a named vector gives one value per (continuous) feature
    missings <- data.table(feature = names(missings_spec$values), x = missings_spec$values)
    checkmate::assert_names(
      missings$feature,
      subset.of = check_continuous,
      .var.name = "missings_spec$values"
    )
  } else {
    missings <- NULL
  }

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  feature <- Model <- NULL # due to NSE notes in R CMD check

  # resolve facet specification
  rfs <- resolve_facet_specification(obj = sculptures[[1]], fs = facet_spec)
  sculptures[[1]] <- rfs$object

  # get continuous vars
  idx_continuous_vars <- !vapply(sculptures[[1]], "[[", logical(1), "is_discrete")
  has_continuous <- any(idx_continuous_vars)
  has_discrete <- any(!idx_continuous_vars)

  # Stack the component functions of all sculptures for either the discrete
  # (`discrete = TRUE`) or the continuous (`discrete = FALSE`) features.
  # `Model` tags each row with the description of the sculpture it came from.
  component_data <- function(discrete) {
    dat <- rbindlist(
      lapply(
        seq_along(sculptures),
        function(i) {
          sc <- sculptures[[i]]
          is_discrete <- vapply(sc, "[[", logical(1), "is_discrete")
          rbindlist(
            lapply(
              sc[is_discrete == discrete],
              function(v) {
                y <- v$predict(v$x)
                if (logodds_to_prob) {
                  y <- inv.logit(y)
                }
                data.table(x = v$x, y = y, Model = descriptions[i])
              }
            ),
            idcol = "feature"
          )
        }
      )
    )
    # fix the legend order to the order of `descriptions`
    dat$Model <- factor(dat$Model, levels = descriptions)
    dat
  }

  if (has_discrete) {
    dat_d <- component_data(discrete = TRUE)
  } else {
    dat_d <- data.table(feature = character(0))
  }

  if (has_continuous) {
    dat_c <- component_data(discrete = FALSE)
  } else {
    dat_c <- data.table(feature = character(0))
  }

  # resolve y limits
  y_limits <- resolve_y_limits(dat_c = dat_c, dat_d = dat_d, facet_scales = facet_spec$scales)

  # resolve facet sort - need to convert to factor
  dat_c[, feature := factor(feature, levels = names(sculptures[[1]])[idx_continuous_vars])]
  dat_d[, feature := factor(feature, levels = names(sculptures[[1]])[!idx_continuous_vars])]

  # resolve missings specification
  rms <- resolve_missings_specification(dat_c = dat_c, ms = missings_spec, missings = missings)
  dat_c <- rms$dat_c
  missings <- rms$missings

  colours <- structure(
    ms_color(length(sculptures), hue_coloring = hue_coloring),
    names = descriptions
  )

  if (nrow(dat_c) > 0) {
    gc <- ggplot(dat_c) +
      geom_line(
        aes(
          x = .data$x,
          y = .data$y,
          colour = .data$Model,
          group = interaction(.data$feature, .data$Model)
        )
      ) +
      # rug only for the reference (first) sculpture
      geom_rug(
        mapping = aes(x = .data$x, y = .data$y),
        data = dat_c[Model == descriptions[1]],
        na.rm = FALSE,
        sides = rug_sides
      ) +
      facet_wrap(
        "feature",
        scales = rfs$scales, ncol = rfs$ncol_c,
        labeller = as_labeller(rfs$labels)
      ) +
      scale_color_manual(values = colours)

    # add missings lines
    if (!is.null(missings)) {
      if (missings_spec$vline) {
        gc <- gc +
          geom_vline(
            mapping = aes(xintercept = .data$x),
            data = missings,
            linetype = "dotted"
          )
      }
      if (missings_spec$hline) {
        gc <- gc +
          geom_hline(
            mapping = aes(
              yintercept = .data$y, linetype = "Score for Missing Feature",
              color = Model
            ),
            data = missings
          ) +
          scale_linetype_manual(values = c("Score for Missing Feature" = "dotted"), name = NULL)
      }
    }

    gc <- gc +
      labs(x = "Features", y = "Feature Score") +
      ylim(y_limits) +
      theme_bw()
  } else {
    gc <- NULL
  }

  if (nrow(dat_d) > 0) {
    gd <- ggplot(dat_d) +
      geom_line(
        aes(
          x = .data$x,
          y = .data$y,
          colour = .data$Model,
          group = interaction(.data$feature, .data$Model)
        )
      ) +
      facet_wrap(
        "feature",
        scales = rfs$scales, ncol = rfs$ncol_d,
        labeller = as_labeller(rfs$labels)
      ) +
      scale_color_manual(values = colours) +
      labs(x = "Features", y = "Feature Score") +
      ylim(y_limits) +
      theme_bw()
  } else {
    gd <- NULL
  }

  return(list(continuous = gc, discrete = gd))
}
| 1 |
#' Create ICE curves at quantiles
#' @keywords internal
#'
#' @param object Object of class sculpture (rough, detailed)
#' @param new_data Data to make quantiles on
#' @param var_name String specifying which variable to generate ICE
#' @param qtiles Quantiles to generate ICE curves
#' @param task Prediction task type (`"regression"` or `"classification"`);
#' unambiguous abbreviations are accepted.
#'
#' @return `data.frame` with one row per combination of quantile and unique
#' value of `var_name`. Columns: `pred_at_1st` (quantile of the predictions with
#' `var_name` fixed at its first value), `qtile`, the `var_name` column,
#' `pred_adjust` (shift of the prediction relative to the first value of
#' `var_name`) and `pred` (`pred_at_1st + pred_adjust`, passed through
#' `inv.logit()` for classification).
#'
#' @details
#' The approach should be amenable to any 1st-order model without interaction
#' terms, however this is not implemented yet (e.g. handling of `predict()`
#' output for a binary endpoint).
#'
calc_ice_quantile <- function(object, new_data, var_name, qtiles = seq(0, 1, by = 0.1),
                              task = "regression") {
  checkmate::assert_class(object, "sculpture")
  # Keep the matched value: otherwise an abbreviation such as "class" passes
  # validation but never equals "classification" in the check further below.
  task <- match.arg(task, c("regression", "classification"))

  # Predict for all samples after replacing the variable with 1st value,
  # then take quantiles
  cov_1st_val <- new_data[[var_name]][1]

  new_data_with_1st_val <- new_data
  new_data_with_1st_val[[var_name]] <- cov_1st_val

  pred_at_1st_val <- predict(object, new_data_with_1st_val)

  pred_qtile_at_1st_qtiles <- stats::quantile(pred_at_1st_val, qtiles)

  pred_at_1st_qtiles <- data.frame(pred_at_1st = pred_qtile_at_1st_qtiles, qtile = qtiles)

  # Separately predict for all values of the cov of interest.
  # The two inputs share no column, so merge() builds their cross product:
  # the first row of the other covariates combined with every unique value.
  # `drop = FALSE` keeps a data.frame (and the column name) even when only a
  # single other covariate is left.
  preds_for_adjust_1 <- merge(
    new_data[1, setdiff(colnames(new_data), var_name), drop = FALSE],
    unique(new_data[var_name])
  )
  preds_for_adjust_1$pred <- predict(object, newdata = preds_for_adjust_1)

  # Get the pred at the first element, because
  # it is what was selected for cov_1st_val
  pred_for_adjust_at_1st <- preds_for_adjust_1$pred[1]

  preds_for_adjust <- preds_for_adjust_1[var_name]
  preds_for_adjust$pred_adjust <- preds_for_adjust_1$pred - pred_for_adjust_at_1st

  # Combine the above 2 to make quantile lines
  # (again a cross product: every quantile with every value of the covariate)
  pred_ice_qtile <- merge(
    pred_at_1st_qtiles,
    preds_for_adjust
  )
  pred_ice_qtile$pred <- pred_ice_qtile$pred_at_1st + pred_ice_qtile$pred_adjust

  # Convert to probabilities if classification
  if (task == "classification") {
    pred_ice_qtile$pred <- inv.logit(pred_ice_qtile$pred)
  }

  return(pred_ice_qtile)
}
|
| 67 | ||
| 68 | ||
| 69 |
#' Create density curves
#' @keywords internal
#'
#' @param new_data_with_pred Data with a `pred` column to estimate the density on
#' @param var_name String naming the column used for the x-axis of the density
#' @param vec_y_expand Optional values that widen the y-axis range
#' @return `data.frame` with columns `x`, `y` and `z` (the 2d kernel density
#' estimate on a 100 x 100 grid), sorted by `x` and then `y`
#'
#' @details
#' The approach should be amenable to any 1st-order model without interaction
#' terms, however this is not implemented yet (e.g. handling of `predict()`
#' output for a binary endpoint).
#'
calc_density <- function(new_data_with_pred, var_name,
                         vec_y_expand = NULL) {
  x_values <- new_data_with_pred[[var_name]]
  y_values <- new_data_with_pred$pred

  # Observed ranges; the y range may be widened by the caller.
  x_span <- range(x_values)
  y_span <- range(c(y_values, vec_y_expand))

  # Evaluation window: 50% padding on both sides of x, 10% on both sides of y.
  grid_limits <- c(
    expand_range(x_span, 0.5, 0.5),
    expand_range(y_span, 0.1, 0.1)
  )

  # MASS::bandwidth.nrd is quantile based and returns 0 when most of the data
  # share the same value (e.g. == 0). In that case use 25% of the range.
  h_x <- MASS::bandwidth.nrd(x_values)
  if (h_x == 0) {
    h_x <- diff(x_span) * 0.25
  }
  h_y <- MASS::bandwidth.nrd(y_values)
  if (h_y == 0) {
    h_y <- diff(y_span) * 0.25
  }

  # Estimate 2d density
  kde <- MASS::kde2d(
    x = x_values,
    y = y_values,
    n = 100,
    lims = grid_limits,
    h = c(h_x, h_y)
  )

  # Long format: expand.grid varies x fastest, which matches the column-major
  # order of as.vector() on the z matrix.
  grid <- expand.grid(x = kde$x, y = kde$y)
  grid$z <- as.vector(kde$z)
  grid <- grid[order(grid$x, grid$y), ]

  # Dummy data to make legend go down to 0.0, replace 1st row z value with 0
  grid$z[1] <- 0

  grid
}
|
| 120 | ||
| 121 | ||
| 122 |
#' Expand the range of values for density plot
#' @keywords internal
#'
#' @param x numeric vector
#' @param expand_left_side Amount to expand on left hand side
#' @param expand_right_side Amount to expand on right hand side
#' @param type `"relative"` (default): the amounts are fractions of the width
#' of the range of `x`; `"absolute"`: the amounts are used as they are.
#'
#' @return Vector of 2 values: the lowered minimum and the raised maximum
#'
expand_range <- function(x, expand_left_side = 0.1, expand_right_side = 0.2,
                         type = c("relative", "absolute")) {
  type <- match.arg(type)

  margins <- c(expand_left_side, expand_right_side)
  if (type == "relative") {
    # scale the fractions by the width of the data
    margins <- margins * diff(range(x))
  }

  c(min(x) - margins[1], max(x) + margins[2])
}
|
| 143 | ||
| 144 |
#' Density plots overlaid with ICE curves |
|
| 145 |
#' |
|
| 146 |
#' Create density plot for the data, overlaid with ICE curves at quantiles |
|
| 147 |
#' of the variable(s) of interest. |
|
| 148 |
#' |
|
| 149 |
#' |
|
| 150 |
#' @name g_density_ice |
|
| 151 |
NULL |
|
| 152 | ||
| 153 | ||
| 154 |
#' @rdname g_density_ice
#' @export
#'
#' @param object Object of class sculpture (rough, detailed)
#' @param new_data Data to make quantiles on
#' @param var_name String specifying which variable to generate ICE
#' @param var_label String (optional) specifying variable label (x label of the plot)
#' @param qtiles Quantiles to generate ICE curves
#' @param task Prediction task type (regression or classification)
#'
#' @return [g_density_ice_plot()]: ggplot object
#'
#' @details
#' [g_density_ice_plot()] creates a density plot for a single variable.
#'
#' [g_density_ice_plot_list()] creates a list of density plots for multiple variables.
#'
#' These functions should be amenable to any 1st-order model without interaction terms;
#' however, this is not implemented yet (for example, handling the `predict()` function
#' output for a binary endpoint).
#'
#' Non-numeric variables are plotted on an integer axis (one position per level)
#' that is labelled with the level names.
#'
#' @examples
#' \dontrun{
#' df <- mtcars
#' df$cyl <- as.factor(df$cyl)
#' model <- lm(hp ~ ., data = df)
#' model_predict <- function(x) predict(model, newdata = x)
#' covariates <- setdiff(colnames(df), "hp")
#' pm <- sample_marginals(df[covariates], n = 50, seed = 5)
#'
#' rs <- sculpt_rough(
#'   dat = pm,
#'   model_predict_fun = model_predict,
#'   n_ice = 5,
#'   seed = 1,
#'   verbose = 0
#' )
#'
#' g_density_ice_plot(rs, new_data = pm, var_name = "mpg")
#' g_list <- g_density_ice_plot_list(
#'   rs, new_data = pm, var_names = c("mpg", "cyl", "disp", "drat")
#' )
#' grid::grid.draw(gridExtra::arrangeGrob(grobs = g_list))
#' }
#'
g_density_ice_plot <- function(object, new_data, var_name, var_label = NULL,
                               qtiles = seq(0, 1, by = 0.1),
                               task = c("regression", "classification")) {
  checkmate::assert_class(object, "sculpture")
  new_data <- check_data(new_data)
  checkmate::assert_string(var_name)
  # Fail early with a clear message instead of deep inside the plotting code.
  checkmate::assert_choice(var_name, names(new_data))
  checkmate::assert_string(var_label, null.ok = TRUE)
  checkmate::assert_numeric(qtiles, lower = 0, upper = 1)
  checkmate::assert_character(task)

  task <- match.arg(task)

  # https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals
  x <- y <- z <- pred <- qtile <- NULL # due to NSE notes in R CMD check

  if (is.null(var_label)) {
    var_label <- var_name
  }

  is_var_discrete <- !is.numeric(new_data[[var_name]])

  new_data_with_pred <- new_data
  new_data_with_pred$pred <- predict(object, newdata = new_data)

  if (task == "classification") {
    new_data_with_pred$pred <- inv.logit(new_data_with_pred$pred)
  }

  pred_ice_qtile <- calc_ice_quantile(
    object, new_data,
    var_name = var_name, qtiles = qtiles, task = task
  )

  if (is_var_discrete) {
    # Recode both data sets against ONE level set, so that an integer position
    # means the same level in the density, in the ICE curves and on the axis
    # labels (independent `as.factor()` calls may disagree when the level sets
    # of the two objects differ).
    var_levels <- levels(as.factor(new_data[[var_name]]))
    new_data_with_pred[[var_name]] <- as.numeric(
      factor(new_data_with_pred[[var_name]], levels = var_levels)
    )
    pred_ice_qtile[[var_name]] <- as.numeric(
      factor(pred_ice_qtile[[var_name]], levels = var_levels)
    )
  }

  density_data <- calc_density(
    new_data_with_pred,
    var_name = var_name,
    vec_y_expand = pred_ice_qtile$pred
  )

  # Extra room on the right-hand side of the x axis is used by the quantile labels.
  x_axis_range_data <- range(new_data_with_pred[[var_name]])
  if (is_var_discrete) {
    x_axis_range_plot <- expand_range(x_axis_range_data, 0.3, 0.3, type = "absolute")
  } else {
    x_axis_range_plot <- expand_range(x_axis_range_data, 0, 0.15)
  }
  y_axis_range_plot <- range(c(new_data_with_pred$pred, pred_ice_qtile$pred))

  # Only the 0%, 50% and 100% curves are highlighted and labelled;
  # labels are anchored at the largest x value of the curves.
  is_key_qtile <- pred_ice_qtile$qtile %in% c(0, 0.5, 1)
  ggrepel_data <- pred_ice_qtile[
    is_key_qtile &
      pred_ice_qtile[[var_name]] == max(pred_ice_qtile[[var_name]]),
  ]

  density_plot <- ggplot(density_data, aes(x = x, y = y)) +
    geom_raster(aes(fill = z), interpolate = TRUE) +
    labs(x = var_label, y = "Predicted Value", fill = "Density") +
    scale_fill_viridis_c() +
    coord_cartesian(xlim = x_axis_range_plot, ylim = y_axis_range_plot) +
    theme(
      panel.ontop = TRUE,
      panel.background = element_rect(color = NA, fill = NA),
      panel.grid.major = element_line(color = grDevices::rgb(1, 1, 1, 0.1)),
      panel.grid.minor = element_line(color = grDevices::rgb(1, 1, 1, 0.1))
    ) +
    geom_line(
      data = pred_ice_qtile,
      aes(x = .data[[var_name]], y = pred, group = qtile), linewidth = 0.3,
      color = "grey70"
    ) +
    geom_line(
      data = pred_ice_qtile[is_key_qtile, ],
      aes(x = .data[[var_name]], y = pred, group = qtile), linewidth = 0.7,
      color = "grey70"
    ) +
    ggrepel::geom_text_repel(
      data = ggrepel_data,
      aes(x = .data[[var_name]], y = pred, label = paste0(round(qtile * 100), "%")),
      color = "grey70",
      box.padding = unit(0.25, "lines"),
      point.padding = unit(0.25, "lines"),
      segment.linetype = "dotted",
      min.segment.length = unit(0, "lines"),
      nudge_x = diff(range(x_axis_range_plot)) / 6,
      direction = "y", hjust = "right"
    )

  if (is_var_discrete) {
    # Show the level names at the integer positions they were recoded to.
    density_plot <- density_plot +
      scale_x_continuous(
        breaks = seq_along(var_levels),
        labels = var_levels,
        minor_breaks = NULL
      )
  }

  return(density_plot)
}
|
| 303 | ||
| 304 | ||
| 305 |
#' @rdname g_density_ice
#' @export
#'
#' @param var_names Vector of strings specifying which variables to generate ICE
#' @param var_labels Named vector of strings specifying variable labels.
#'   Names are variable names; variables without a (non-missing) entry are
#'   labelled with their variable name.
#'
#' @return [g_density_ice_plot_list()]: list of ggplot objects
#'
g_density_ice_plot_list <- function(object, new_data, var_names, var_labels = NULL,
                                    qtiles = seq(0, 1, by = 0.1),
                                    task = c("regression", "classification")) {
  checkmate::assert_class(object, "sculpture")
  new_data <- check_data(new_data)
  checkmate::assert_character(var_names, any.missing = FALSE)
  checkmate::assert_character(var_labels, null.ok = TRUE)
  checkmate::assert_numeric(qtiles, lower = 0, upper = 1)
  checkmate::assert_character(task)

  task <- match.arg(task)

  out <- vector("list", length(var_names))
  names(out) <- var_names

  # Fill by position (not by name) so that repeated variable names
  # do not leave empty slots in the result.
  for (i in seq_along(var_names)) {
    var_name <- var_names[[i]]

    # `var_labels[var_name]` would give NA for a variable without a label
    # (or for an unnamed `var_labels`) and fail the string check downstream;
    # pass NULL instead so the single-plot function uses its default label.
    var_label <- NULL
    if (var_name %in% names(var_labels) && !is.na(var_labels[[var_name]])) {
      var_label <- var_labels[[var_name]]
    }

    out[[i]] <- g_density_ice_plot(object, new_data, var_name, var_label, qtiles, task)
  }

  return(out)
}
| 1 |
# Private state: the cluster created by `parallel_set()`, kept so that
# `parallel_end()` can shut its worker processes down.
.parallel_cluster_env <- new.env(parent = emptyenv())

#' Set and end parallel computation
#'
#' `parallel_set()` starts a cluster and registers it as the parallel backend
#' for `foreach`. `parallel_end()` stops the cluster started by
#' `parallel_set()` (if any) and registers the sequential backend again.
#'
#' @param num_cores (`integer`) Number of cores.
#' @param cluster_type (`character`) Type of cluster. One of `c("fork", "psock")`.
#'   Fork clusters are created with [parallel::makeForkCluster()], which is not
#'   available on Windows; use `"psock"` there.
#'
#' @export
#' @examples
#' \dontrun{
#' parallel_set(num_cores = 2)
#' # now the code will run on parallel with 2 cores
#' parallel_end()
#' # now the code will run sequentially
#' }
parallel_set <- function(num_cores = 10, cluster_type = "fork") {
  checkmate::assert_integerish(num_cores, lower = 1, any.missing = FALSE, len = 1)
  cluster_type <- match.arg(cluster_type, choices = c("fork", "psock"))

  # Tear down any backend left over from a previous call.
  parallel_end()

  if (cluster_type == "fork") {
    cl <- parallel::makeForkCluster(num_cores)
  } else {
    cl <- parallel::makePSOCKcluster(num_cores)
  }
  # Remember the cluster; without the handle its workers could never be stopped.
  .parallel_cluster_env$cluster <- cl

  doParallel::registerDoParallel(cl)

  message(paste("Using", foreach::getDoParWorkers(), "cores")) # should be == num_cores
}

#' @rdname parallel_set
#' @export
parallel_end <- function() {
  cl <- .parallel_cluster_env$cluster
  if (!is.null(cl)) {
    .parallel_cluster_env$cluster <- NULL
    # Best effort: the cluster may already have been stopped by the user.
    tryCatch(
      parallel::stopCluster(cl),
      error = function(e) NULL
    )
  }

  if (foreach::getDoParRegistered()) {
    foreach::registerDoSEQ()
  }
}
|
| 38 | ||
| 39 | ||
| 40 |
# Pick the `foreach` operator to iterate with: `%dopar%` only when a parallel
# backend is registered and the caller allows parallel execution,
# `%do%` in every other case.
define_foreach_operand <- function(allow_par = FALSE) {
  run_parallel <- foreach::getDoParRegistered() && allow_par

  if (!run_parallel) {
    return(foreach::`%do%`)
  }

  foreach::`%dopar%`
}
| 1 |
# Jittered strip plot of centred PDP values (`pdp_c`) per feature.
#
# dt              Data with a factor column `feature` and a column `pdp_c`.
# pdp_plot_sample If TRUE and `dt` has more than 40,000 rows, plot a fixed-seed
#                 random sample of 40,000 rows (ensures faster rendering).
# feat_labels     Named vector used to translate feature names to axis labels.
#
# Returns a ggplot object. The caller's random number generator state is left
# untouched, even though the row sample is drawn with a fixed seed.
g_pdp <- function(dt, pdp_plot_sample, feat_labels) {
  max_points <- 4e4
  plot_data <- dt
  point_alpha <- 0.2

  if (pdp_plot_sample && nrow(dt) > max_points) {
    # Save the global RNG state and restore it on exit, so that the fixed seed
    # used for a reproducible sample does not leak into the caller's session.
    global_env <- globalenv()
    had_seed <- exists(".Random.seed", envir = global_env, inherits = FALSE)
    if (had_seed) {
      old_seed <- get(".Random.seed", envir = global_env, inherits = FALSE)
    }
    on.exit(
      {
        if (had_seed) {
          assign(".Random.seed", old_seed, envir = global_env)
        } else if (exists(".Random.seed", envir = global_env, inherits = FALSE)) {
          rm(".Random.seed", envir = global_env)
        }
      },
      add = TRUE
    )

    set.seed(101)
    plot_data <- dt[sample(nrow(dt), max_points), ]
    # Fewer points are drawn, so make each one more opaque.
    point_alpha <- 0.7
  }

  g <- ggplot(
    data = plot_data,
    mapping = aes(y = factor(.data$feature, levels = rev(levels(.data$feature))), x = .data$pdp_c)
  ) +
    geom_jitter(
      shape = 16, size = 1.5, alpha = point_alpha,
      position = position_jitter(seed = 1)
    ) +
    scale_y_discrete(labels = function(x) feat_labels[x]) +
    labs(x = "Feature Score", y = "Feature") +
    theme_bw()

  return(g)
}
|
| 25 | ||
| 26 |
# Dot plot of absolute direct variable importance (`variance`) per feature,
# with the rounded value printed next to each point.
#
# dat_var       Data with a factor column `feature` and a column `variance`.
# show_pdp_plot If TRUE the y axis title is left empty, otherwise it is "Feature".
# textsize      Text size; the label size used is `round(textsize / 3)`.
#
# Returns a ggplot object.
g_imp_abs <- function(dat_var, show_pdp_plot, textsize) {
  top_variance <- max(dat_var$variance)
  label_shift <- top_variance / 5
  dat_var$variance_vs_top <- dat_var$variance / top_variance
  y_axis_title <- ifelse(show_pdp_plot, "", "Feature")

  point_mapping <- aes(
    y = factor(.data$feature, levels = rev(levels(.data$feature))),
    x = .data$variance
  )

  # Labels are nudged to the right of the point by default; for features with
  # more than half of the top variance the anchor is moved left first.
  label_mapping <- aes(
    x = ifelse(
      .data$variance_vs_top > 0.5,
      .data$variance - 2 * label_shift,
      .data$variance
    ),
    label = format(round(.data$variance, 3), nsmall = 3, digits = 3)
  )

  ggplot(dat_var, point_mapping) +
    geom_point() +
    geom_text(
      mapping = label_mapping,
      nudge_x = label_shift,
      size = round(textsize / 3)
    ) +
    labs(x = "Direct Variable Importance", y = y_axis_title) +
    theme_bw()
}
|
| 54 | ||
| 55 |
# Dot plot of normalized direct variable importance (`ratio`, shown on a
# 0-1 axis) per feature, with the value printed as a percentage.
#
# dat_var       Data with a factor column `feature` and a column `ratio`.
# show_pdp_plot If TRUE the y axis title is left empty, otherwise it is "Feature".
# textsize      Text size; the label size used is `round(textsize / 3)`.
#
# Returns a ggplot object.
g_imp_norm <- function(dat_var, show_pdp_plot, textsize) {
  y_axis_title <- ifelse(show_pdp_plot, "", "Feature")

  point_mapping <- aes(
    y = factor(.data$feature, levels = rev(levels(.data$feature))),
    x = .data$ratio
  )

  # Labels sit to the right of the point (nudge of 0.2) unless the ratio is
  # above 0.75, in which case the anchor is first moved 0.4 to the left.
  label_mapping <- aes(
    x = ifelse(.data$ratio > 0.75, .data$ratio - 0.4, .data$ratio),
    label = sprintf("%.1f%%", round(.data$ratio * 100, 1))
  )

  ggplot(dat_var, point_mapping) +
    geom_point() +
    geom_text(
      mapping = label_mapping,
      nudge_x = 0.2,
      size = round(textsize / 3)
    ) +
    xlim(c(0, 1)) +
    labs(x = "Direct Variable Importance", y = y_axis_title) +
    theme_bw()
}
|
| 77 | ||
| 78 |
# Dot plot of direct variable importance per feature: small grey points for
# the individual values in `vars` (column `var_y`) and larger black points for
# the per-feature values in `vars_mean` (column `mean_var_y`).
#
# Both inputs need a factor column `feature`. Returns a ggplot object.
g_imp_ice <- function(vars, vars_mean) {
  individual_mapping <- aes(
    y = factor(.data$feature, levels = rev(levels(.data$feature))),
    x = .data$var_y
  )
  mean_mapping <- aes(
    y = factor(.data$feature, levels = rev(levels(.data$feature))),
    x = .data$mean_var_y
  )

  individual_points <- geom_point(
    individual_mapping,
    data = vars,
    size = 1,
    colour = "gray50"
  )
  mean_points <- geom_point(
    mean_mapping,
    data = vars_mean,
    size = 2,
    colour = "black"
  )

  ggplot() +
    individual_points +
    mean_points +
    labs(x = "Direct Variable Importance", y = "Feature") +
    theme_bw()
}
|
| 99 | ||
| 100 |
# Dot plot of the cumulative approximation R^2 per feature (0-1 axis), with
# the value printed as a percentage. The y axis text and ticks are hidden.
#
# dat_R2_cumul Data with a factor column `feature` and a column `R2`.
# textsize     Text size; the label size used is `round(textsize / 3)`.
#
# Returns a ggplot object.
g_cumulR2 <- function(dat_R2_cumul, textsize) {
  point_mapping <- aes(
    y = factor(.data$feature, levels = rev(levels(.data$feature))),
    x = round(.data$R2, 4)
  )

  # Labels sit to the left of the point (nudge of -0.2) unless R2 is below
  # 0.25, in which case the anchor is first moved 0.4 to the right.
  label_mapping <- aes(
    x = ifelse(.data$R2 < 0.25, .data$R2 + 0.4, .data$R2),
    label = sprintf("%.1f%%", round(.data$R2 * 100, 1))
  )

  hide_y_axis <- theme(axis.ticks.y = element_blank(), axis.text.y = element_blank())

  ggplot(dat_R2_cumul, point_mapping) +
    geom_point() +
    geom_text(
      mapping = label_mapping,
      nudge_x = -0.2,
      size = round(textsize / 3)
    ) +
    xlim(c(0, 1)) +
    labs(
      x = expression("Cumulative Approximation " * R^2),
      y = ""
    ) +
    theme_bw() +
    hide_y_axis
}