commit4eff300b6b36608e58a964fa9ee048c3d28a43d7parent31edf6887b01cac0cdc5e412a8190c94ff8156d5Author:Eamon Caddigan <eamon.caddigan@gmail.com>Date:Sun, 8 May 2022 23:46:23 -0400 Code for centering/scaling data responsibly for ML contextsDiffstat:

A | scale_and_apply.Rmd | | | 87 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |

1 file changed, 87 insertions(+), 0 deletions(-)diff --git a/scale_and_apply.Rmd b/scale_and_apply.Rmd@@ -0,0 +1,87 @@ +--- +title: "R Notebook" +output: html_notebook +--- + +# Scale and Apply + +When fitting regression models to data (other models too, but regression especially), it's useful to center variables to a mean of 0 and (depending on the type of variable) scale them to a standard deviation of 1. I'm feeling lazy and imprecice, so I'll flip between calling this operation "normalizing" or "scaling". + +When performing machine learning, it's best practice to withold the testing set and perform the normalization only on the training set, and then apply the training set features' mean and sd to those in the testing set. This is one of those things that I was taught by my grad advisor and have taken to heart, but I don't think everybody is so careful about this. + +Regardless, I feel like I write the same code again and again, so I'm going to stash an implementation in this notebook and put it on Git for future use. + +There are two steps: `find_norm` is called on a data frame and returns a named list. `apply_norm` is called with the aforementioned named list and a data frame (which can be the same one used in `find_norm` or a new with the same column names) and returns a normalized data frame. + +* Numeric columns in the range [0, 1] (integer or double) are centered but not scaled per my recollection of a suggestion from Andrew Gelman on handling binary variables +* All other numeric columns are centered and scaled +* Non-numeric columns (including factors) are left alone; convert them to numeric first if you want this to do anything + +`mtcars` is a good demo; we'll ignore row names, and we have a couple variables in the range [0, 1]. We'll convert the number of cylinders to a factor. + +```{r update_cars} +mtcars$cyl <- as.factor(mtcars$cyl) +``` + +We'll split the data into a training and test set. + +```{r} +set.seed(414726326) +all_id <- sample(seq(nrow(mtcars))) +test_cars <- mtcars[all_id[1:6], ] +train_cars <- mtcars[all_id[7:length(all_id)], ] +``` + +Define `find_norm` and run it on our "training data" + +```{r find_norm} +find_norm <- function(dat) { + center_and_scale <- function(x) { + if (is.numeric(x)) { + x_mean <- mean(x, na.rm = TRUE) + if (min(x) >= 0 && max(x) <= 1) + c(center = x_mean) + else + c(center = x_mean, scale = sd(x, na.rm = TRUE)) + } else { + NULL + } + } + + lapply(dat, center_and_scale) +} + +norm_list <- find_norm(train_cars) +``` + +Now we'll apply it to the same data we used to find the normalization values. For numeric data outside the range [0, 1], this is the same as applying `base::scale(x, center = T, scale = T)`. + +```{r apply_norm} +# Any columns present in `dat` and missing in `norm_list` will be ignored. Any +# columns missing from `dat` but present in `norm_list` will throw an error; +# I'll leave it to downstream me to deal with that. +apply_norm <- function(dat, norm_list) { + for (i in seq_along(norm_list)) { + col_name <- names(norm_list)[i] + if (!is.null(norm_list[[i]])) { + dat[[col_name]] <- dat[[col_name]] - norm_list[[i]]['center'] + if (!is.na(norm_list[[i]]['scale'])) { + dat[[col_name]] <- dat[[col_name]] / norm_list[[i]]['scale'] + } + } + } + dat +} + +train_cars_normed <- apply_norm(train_cars, norm_list) +zapsmall(c(mean(train_cars_normed$mpg), sd(train_cars_normed$mpg))) +``` + +We can apply the same normalization values to a different set of data... + +```{r apply_norm_different_data} +test_cars_normed <- apply_norm(test_cars, norm_list) +zapsmall(c(mean(test_cars_normed$mpg), sd(test_cars_normed$mpg))) +``` + +And now we don't have data leakage from our training data to our testing data!