#' Template: fit a linear model after optional preprocessing
#'
#' Illustrative sketch only: `some_imputation_function()` and
#' `function_for_removing_outliers()` are placeholders for real helpers.
#'
#' @param df Data frame to model.
#' @param impute `TRUE`/`FALSE`: run the imputation step before fitting?
#' @param remove_outliers `TRUE`/`FALSE`: run the outlier-removal step before
#'   fitting?
#' @param mod Model formula handed to `lm()`.
#' @return The fitted `lm` object.
fit_model <- function(df, impute, remove_outliers, mod) {
  if (impute) {
    df <- some_imputation_function(df)
  }
  if (remove_outliers) {
    df <- function_for_removing_outliers(df)
  }
  lm(mod, data = df)
}
Setting the Stage
My number one use case for writing functions and iteration / looping is to perform some exploration or modeling repeatedly for different “tweaked” versions. For example, our broad goal might be to fit a linear regression model to our data. However, there are often multiple choices that we have to make in practice:
- Keep missing values or fill them in (imputation)?
- Filter out outliers in one or more variables?
We can map these choices to arguments in a custom model-fitting function:
- impute: TRUE or FALSE
- remove_outliers: TRUE or FALSE
A function that implements the analysis and allows for variation in these choices:
Helper Functions
Goal 1: Write a function that removes outliers in a dataset. The user should be able to supply the dataset, the variables to remove outliers from, and a threshold on the number of SDs away from the mean used to define outliers.
#' Remove rows that contain outliers in the selected numeric columns
#'
#' A value is treated as an outlier when its absolute z-score (distance from
#' the column mean in standard deviations) exceeds `sd_thresh`. Missing values
#' are not treated as outliers, so rows are never dropped just because a
#' screened column is `NA`.
#'
#' @param data A data frame (anything `as.data.table()` accepts).
#' @param ... Unquoted names of the columns to screen. Non-numeric columns are
#'   skipped with a warning.
#' @param sd_thresh Number of standard deviations from the mean beyond which a
#'   value counts as an outlier. Defaults to 3.
#' @return A data.table holding the rows of `data` that have no outlier in any
#'   of the screened numeric columns.
remove_outliers <- function(data, ..., sd_thresh = 3) {
  dt <- as.data.table(data)

  # Capture the unquoted column names supplied through `...`
  cols <- as.character(substitute(list(...)))[-1]

  # Fail early with a clear message instead of mislabelling unknown columns
  missing_cols <- setdiff(cols, names(dt))
  if (length(missing_cols) > 0) {
    stop("The following columns are not in the data: ",
         paste(missing_cols, collapse = ", "))
  }

  # vapply() always returns a logical vector, even when no columns were given
  # (sapply() would return an empty list and break the subsetting below)
  is_numeric <- vapply(cols, function(nm) is.numeric(dt[[nm]]), logical(1),
                       USE.NAMES = FALSE)
  numeric_cols <- cols[is_numeric]
  non_numeric <- cols[!is_numeric]

  if (length(non_numeric) > 0) {
    warning("The following columns are not numeric and will be skipped: ",
            paste(non_numeric, collapse = ", "))
  }
  if (length(numeric_cols) == 0) {
    return(dt) # Nothing numeric to screen, so hand the table back unchanged
  }

  # Absolute z-scores for all screened columns at once; as.vector() drops the
  # one-column matrix that scale() returns
  z_scores <- dt[, lapply(.SD, function(col) abs(as.vector(scale(col)))),
                 .SDcols = numeric_cols]

  # Keep rows where no z-score exceeds the threshold. na.rm = TRUE stops NA or
  # NaN z-scores (missing values, zero-variance columns) from counting as
  # outliers.
  keep_rows <- rowSums(z_scores > sd_thresh, na.rm = TRUE) == 0
  dt[keep_rows]
}
Testing My Function!
# Screen several numeric columns in one call, using the default threshold
remove_outliers(data = diamonds, price, x, y, z)
carat cut color clarity depth table price x y z
<num> <ord> <ord> <ord> <num> <num> <int> <num> <num> <num>
1: 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2: 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3: 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4: 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
5: 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
---
52685: 0.72 Ideal D SI1 60.8 57 2757 5.75 5.76 3.50
52686: 0.72 Good D SI1 63.1 55 2757 5.69 5.75 3.61
52687: 0.70 Very Good D SI1 62.8 60 2757 5.66 5.68 3.56
52688: 0.86 Premium H SI2 61.0 58 2757 6.15 6.12 3.74
52689: 0.75 Ideal D SI2 62.2 55 2757 5.83 5.87 3.64
# Pass a non-numeric column (color) alongside a numeric one to show the warning
remove_outliers(data = diamonds, price, color)
Warning in remove_outliers(diamonds, price, color): The following columns are
not numeric and will be skipped: color
carat cut color clarity depth table price x y z
<num> <ord> <ord> <ord> <num> <num> <int> <num> <num> <num>
1: 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2: 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3: 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4: 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
5: 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
---
52730: 0.72 Ideal D SI1 60.8 57 2757 5.75 5.76 3.50
52731: 0.72 Good D SI1 63.1 55 2757 5.69 5.75 3.61
52732: 0.70 Very Good D SI1 62.8 60 2757 5.66 5.68 3.56
52733: 0.86 Premium H SI2 61.0 58 2757 6.15 6.12 3.74
52734: 0.75 Ideal D SI2 62.2 55 2757 5.83 5.87 3.64
# Tighten the cutoff from the default to 2 standard deviations
remove_outliers(data = diamonds, price, x, y, z, sd_thresh = 2)
carat cut color clarity depth table price x y z
<num> <ord> <ord> <ord> <num> <num> <int> <num> <num> <num>
1: 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2: 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3: 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4: 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
5: 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
---
50095: 0.72 Ideal D SI1 60.8 57 2757 5.75 5.76 3.50
50096: 0.72 Good D SI1 63.1 55 2757 5.69 5.75 3.61
50097: 0.70 Very Good D SI1 62.8 60 2757 5.66 5.68 3.56
50098: 0.86 Premium H SI2 61.0 58 2757 6.15 6.12 3.74
50099: 0.75 Ideal D SI2 62.2 55 2757 5.83 5.87 3.64
Goal 2: Write a function that imputes missing values for numeric variables in a dataset. The user should be able to supply the dataset, the variables to impute values for, and a function to use when imputing.
#' Impute missing values in the selected numeric columns
#'
#' Each `NA` in a selected numeric column is replaced by the result of applying
#' `impute_fun` to that column's non-missing values.
#'
#' @param data A data frame or tibble.
#' @param ... Unquoted names of the columns to impute. Non-numeric columns are
#'   skipped with a warning.
#' @param impute_fun Summary function used to compute the fill value. It is
#'   called as `impute_fun(x, na.rm = TRUE)`, so it must accept `na.rm`.
#'   Defaults to `mean`.
#' @return `data` with missing values filled in the selected numeric columns.
impute_missing_dt <- function(data, ..., impute_fun = mean) {
  # Capture the unquoted column names supplied through `...`
  cols <- as.character(substitute(list(...)))[-1]

  # Fail early with a clear message instead of mislabelling unknown columns
  missing_cols <- setdiff(cols, names(data))
  if (length(missing_cols) > 0) {
    stop("The following columns are not in the data: ",
         paste(missing_cols, collapse = ", "))
  }

  # Check each column with `[[` so the test does not depend on how `[`
  # behaves for this class of data; vapply() also copes with zero columns
  is_numeric <- vapply(cols, function(nm) is.numeric(data[[nm]]), logical(1),
                       USE.NAMES = FALSE)
  non_numeric <- cols[!is_numeric]

  if (length(non_numeric) > 0) {
    warning("The following columns are not numeric and will be skipped: ",
            paste(non_numeric, collapse = ", "))
  }

  # Only the numeric columns are imputed
  numeric_cols <- setdiff(cols, non_numeric)

  data |>
    mutate(across(
      .cols = all_of(numeric_cols),
      .fns = \(x) replace_na(x, impute_fun(x, na.rm = TRUE))
    ))
}
Testing My Function!
# Impute two numeric columns in one call with the default impute_fun (mean)
impute_missing_dt(data = nycflights13::flights, arr_delay, dep_delay)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
# Pass a non-numeric column (carrier) alongside a numeric one to show the warning
impute_missing_dt(data = nycflights13::flights, arr_delay, carrier)
Warning in impute_missing_dt(nycflights13::flights, arr_delay, carrier): The
following columns are not numeric and will be skipped: carrier
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
# Swap the default mean for the median when filling missing values
impute_missing_dt(
  data = nycflights13::flights,
  arr_delay,
  dep_delay,
  impute_fun = median
)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
Primary Function
Goal 3: Write a fit_model() function that fits a specified linear regression model for a specified dataset. The function should:
- allow the user to specify if outliers should be removed (TRUE or FALSE)
- allow the user to specify if missing observations should be imputed (TRUE or FALSE)
If either option is TRUE, the function should call my remove_outliers() or impute_missing_dt() functions to modify the data before the regression model is fit.
#' Fit a linear regression after optional imputation and outlier removal
#'
#' Imputation (if requested) runs first, then outlier removal, then `lm()`.
#'
#' @param df Data frame to model.
#' @param impute_missing `TRUE`/`FALSE`: fill missing values with
#'   `impute_missing_dt()` before fitting?
#' @param remove_outliers `TRUE`/`FALSE`: drop outlier rows with
#'   `remove_outliers()` before fitting?
#' @param mod_formula Model formula handed to `lm()`.
#' @param ... Unquoted column names forwarded to both preprocessing helpers.
#' @param impute_fun Function forwarded to `impute_missing_dt()`; only used
#'   when `impute_missing` is `TRUE`. Defaults to `mean`.
#' @param sd_thresh Threshold forwarded to `remove_outliers()`; only used when
#'   `remove_outliers` is `TRUE`. Defaults to 3.
#' @return The fitted `lm` object.
fit_model <- function(df, impute_missing = FALSE, remove_outliers = FALSE,
                      mod_formula, ..., impute_fun = mean, sd_thresh = 3) {
  # The helper options are named arguments rather than part of `...`:
  # each helper treats everything in `...` as a column name, so an option
  # meant for one helper would be misread as a column by the other.
  if (impute_missing) {
    df <- impute_missing_dt(df, ..., impute_fun = impute_fun)
  }

  # The logical argument shares its name with the helper; R still resolves the
  # call below to the function because it skips non-function bindings there.
  if (remove_outliers) {
    df <- remove_outliers(df, ..., sd_thresh = sd_thresh)
  }

  # Fit and return the linear regression model
  lm(mod_formula, data = df)
}
Testing My Function!
# Run both preprocessing steps on price and carat, then fit price ~ carat + cut
fit_model(
  df = diamonds,
  impute_missing = TRUE,
  remove_outliers = TRUE,
  mod_formula = price ~ carat + cut,
  price,
  carat
)
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2460.16 7526.96 1059.65 -410.54 295.80 82.62
Iteration
In the diamonds dataset, I want to understand the relationship between price and size (carat). Specifically, I want to explore variation along three choices:
1. The variables included in the model. I’ll explore 4 sets of variables:
   - No further variables (just price and carat)
   - Adjusting for cut
   - Adjusting for cut and clarity
   - Adjusting for cut, clarity, and color
2. Whether or not to impute missing values
3. Whether or not to remove outliers in the carat variable (I’ll define outliers as cases whose carat is over 3 SDs away from the mean).
Parameters
First, I need to define the set of parameters I want to iterate the fit_model() function over. The tidyr package provides crossing(), which is useful for generating argument combinations. For each argument, I specify all possible values for that argument, and crossing() generates all combinations.
# Template: one row per combination of the three argument choices
df_arg_combos <- crossing(
  impute = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod = c(y ~ x1,
          y ~ x1 + x2)
)

# Show the resulting grid of combinations
df_arg_combos
Goal 4: Use crossing()
to create the data frame of argument combinations for our analyses.
# Argument grid for fit_model(): the column names match its parameter names
df_arg_combos <- crossing(
  impute_missing = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod_formula = list(
    price ~ carat, #No additional variables
    price ~ carat + cut, #Adjusting for cut
    price ~ carat + cut + clarity, #Adjusting for cut and clarity
    price ~ carat + cut + clarity + color #Adjusting for all variables
  )
)
Iterating Over the Parameters
I’ve arrived at the final step!
Goal 5: Use pmap() from purrr to apply the fit_model() function to every combination of arguments, using the diamonds dataset.
# Apply fit_model() to every row of argument combinations using pmap().
# `df = diamonds` is held fixed; `price, carat` are forwarded through `...`
# as the columns to impute and/or screen for outliers.
# NOTE(review): the accompanying text describes outlier removal on carat only,
# but price is passed here as well — confirm which is intended.
results <- pmap(df_arg_combos, fit_model, df = diamonds, price, carat)

# Fitted models for each combination of arguments
results
[[1]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat
-2256 7756
[[2]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[3]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[4]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[5]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat
-2067 7411
[[6]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2460.16 7526.96 1059.65 -410.54 295.80 82.62
[[7]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2897.60 8118.13 618.65 -268.13 148.57 8.10
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3463.17 -1505.52 609.45 -286.81 158.48 71.83
clarity^7
183.33
[[8]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3376.856 8513.377 607.668 -263.835 142.453 5.506
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3654.785 -1503.171 612.049 -225.568 124.920 47.356
clarity^7 color.L color.Q color.C color^4 color^5
112.889 -1714.316 -539.239 -127.282 43.649 -60.866
color^6
-49.333
[[9]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat
-2256 7756
[[10]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[11]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[12]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[13]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat
-2067 7411
[[14]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2460.16 7526.96 1059.65 -410.54 295.80 82.62
[[15]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2897.60 8118.13 618.65 -268.13 148.57 8.10
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3463.17 -1505.52 609.45 -286.81 158.48 71.83
clarity^7
183.33
[[16]]
Call:
lm(formula = mod_formula, data = df)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3376.856 8513.377 607.668 -263.835 142.453 5.506
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3654.785 -1503.171 612.049 -225.568 124.920 47.356
clarity^7 color.L color.Q color.C color^4 color^5
112.889 -1714.316 -539.239 -127.282 43.649 -60.866
color^6
-49.333