Code Efficiency

I really wanted to get better at writing functions in R, so I studied up on some best practices. Then I tried to accomplish a few tasks, making my functions as efficient as I could. This was one of the first times I had written functions in R, so I don't think they are the absolute best, but they serve as a great reference point as I keep improving!
Functions
Published May 12, 2025


Setting the Stage

My number one use case for writing functions and iteration (looping) is to repeat an exploration or modeling step across several slightly "tweaked" versions of the same analysis. For example, our broad goal might be to fit a linear regression model to our data. In practice, though, we often have to make several choices along the way:

  • Keep missing values or fill them in (imputation)?
  • Filter out outliers in one or more variables?

We can map these choices to arguments in a custom model-fitting function:

  • impute: TRUE or FALSE
  • remove_outliers: TRUE or FALSE

Here is a skeleton of a function that implements the analysis and lets these choices vary. The two helper functions are placeholders for now:

fit_model <- function(df, impute, remove_outliers, mod) {
    if (impute) {
        df <- some_imputation_function(df)
    }
    
    if (remove_outliers) {
        df <- function_for_removing_outliers(df)
    }
    
    lm(mod, data = df)
}
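To check that the skeleton's control flow actually works, here is a minimal runnable sketch. The two placeholder helpers are filled in with hypothetical base-R stand-ins that work only on the Ozone column; they are not the helpers built later in this post. The example uses base R's airquality dataset instead of diamonds, because airquality has missing Ozone values, so the effect of the impute switch is visible.

```r
# Hypothetical stand-in for some_imputation_function(): mean-impute Ozone
some_imputation_function <- function(df) {
  df$Ozone[is.na(df$Ozone)] <- mean(df$Ozone, na.rm = TRUE)
  df
}

# Hypothetical stand-in for function_for_removing_outliers():
# keep rows whose Ozone is within 3 SDs of the mean (or is missing)
function_for_removing_outliers <- function(df) {
  z <- abs(as.vector(scale(df$Ozone)))
  df[is.na(z) | z <= 3, ]
}

# The skeleton from above, unchanged apart from indentation
fit_model <- function(df, impute, remove_outliers, mod) {
  if (impute) {
    df <- some_imputation_function(df)
  }
  if (remove_outliers) {
    df <- function_for_removing_outliers(df)
  }
  lm(mod, data = df)
}

# airquality ships with base R and has missing Ozone values
m_raw     <- fit_model(airquality, impute = FALSE, remove_outliers = FALSE, mod = Ozone ~ Temp)
m_imputed <- fit_model(airquality, impute = TRUE,  remove_outliers = FALSE, mod = Ozone ~ Temp)

nobs(m_raw)      # lm() silently drops rows with missing Ozone
nobs(m_imputed)  # every row is used after imputation
```

Flipping a single logical argument changes how many rows reach lm(). That is exactly the kind of analysis choice the skeleton is meant to expose.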

Helper Functions

Goal 1: Write a function that removes outliers in a dataset. The user should be able to supply the dataset, the variables to remove outliers from, and a threshold on the number of SDs away from the mean used to define outliers.

library(data.table)

remove_outliers <- function(data, ..., sd_thresh = 3) {
  dt <- as.data.table(data)
  
  # Capture the bare column names passed through ...
  cols <- as.character(substitute(list(...)))[-1]
  
  # Check which of the requested columns are numeric
  is_numeric <- sapply(dt[, ..cols], is.numeric)
  numeric_cols <- cols[is_numeric]
  non_numeric <- cols[!is_numeric]
  
  if (length(non_numeric) > 0) {
    warning("The following columns are not numeric and will be skipped: ",
            paste(non_numeric, collapse = ", "))
  }
  
  if (length(numeric_cols) == 0) {
    return(dt)  # No numeric columns to filter, so return the data unchanged
  }
  
  # Compute absolute z-scores for all numeric columns at once
  # (as.vector() drops the one-column matrix that scale() returns)
  z_scores <- dt[, lapply(.SD, function(col) abs(as.vector(scale(col)))),
                 .SDcols = numeric_cols]
  
  # Keep rows where every z-score is within the threshold
  # (na.rm = TRUE keeps rows whose only issue is a missing value)
  keep_rows <- rowSums(z_scores > sd_thresh, na.rm = TRUE) == 0
  
  dt[keep_rows]
}
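The core of remove_outliers() is the z-score mask. Here is a small sketch of that step on a made-up vector, using the same scale() call and the same "drop if above sd_thresh" rule as the function.

```r
# Toy data: twenty identical values plus one extreme value
x <- c(rep(10, 20), 100)

# Absolute z-scores, computed the same way as inside remove_outliers()
z <- abs(as.vector(scale(x)))

# Values with a z-score above the threshold would be dropped
sd_thresh <- 3
keep <- !(z > sd_thresh)

round(z, 2)
x[keep]
```

Only the value 100 is flagged. Its z-score works out to (n - 1)/sqrt(n) = 20/sqrt(21), about 4.36, which is also the largest z-score any sample of 21 values can produce. A related consequence is that, with 10 or fewer rows, no value can exceed a 3-SD threshold at all, so sd_thresh only matters for reasonably large datasets.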

Testing My Function!

library(ggplot2)  # provides the diamonds dataset

## Testing how your function handles multiple input variables
remove_outliers(diamonds, 
                price, 
                x, 
                y, 
                z)
       carat       cut color clarity depth table price     x     y     z
       <num>     <ord> <ord>   <ord> <num> <num> <int> <num> <num> <num>
    1:  0.23     Ideal     E     SI2  61.5    55   326  3.95  3.98  2.43
    2:  0.21   Premium     E     SI1  59.8    61   326  3.89  3.84  2.31
    3:  0.23      Good     E     VS1  56.9    65   327  4.05  4.07  2.31
    4:  0.29   Premium     I     VS2  62.4    58   334  4.20  4.23  2.63
    5:  0.31      Good     J     SI2  63.3    58   335  4.34  4.35  2.75
   ---                                                                  
52685:  0.72     Ideal     D     SI1  60.8    57  2757  5.75  5.76  3.50
52686:  0.72      Good     D     SI1  63.1    55  2757  5.69  5.75  3.61
52687:  0.70 Very Good     D     SI1  62.8    60  2757  5.66  5.68  3.56
52688:  0.86   Premium     H     SI2  61.0    58  2757  6.15  6.12  3.74
52689:  0.75     Ideal     D     SI2  62.2    55  2757  5.83  5.87  3.64
## Testing how your function handles an input that isn't numeric
remove_outliers(diamonds, 
                price, 
                color)
Warning in remove_outliers(diamonds, price, color): The following columns are
not numeric and will be skipped: color
       carat       cut color clarity depth table price     x     y     z
       <num>     <ord> <ord>   <ord> <num> <num> <int> <num> <num> <num>
    1:  0.23     Ideal     E     SI2  61.5    55   326  3.95  3.98  2.43
    2:  0.21   Premium     E     SI1  59.8    61   326  3.89  3.84  2.31
    3:  0.23      Good     E     VS1  56.9    65   327  4.05  4.07  2.31
    4:  0.29   Premium     I     VS2  62.4    58   334  4.20  4.23  2.63
    5:  0.31      Good     J     SI2  63.3    58   335  4.34  4.35  2.75
   ---                                                                  
52730:  0.72     Ideal     D     SI1  60.8    57  2757  5.75  5.76  3.50
52731:  0.72      Good     D     SI1  63.1    55  2757  5.69  5.75  3.61
52732:  0.70 Very Good     D     SI1  62.8    60  2757  5.66  5.68  3.56
52733:  0.86   Premium     H     SI2  61.0    58  2757  6.15  6.12  3.74
52734:  0.75     Ideal     D     SI2  62.2    55  2757  5.83  5.87  3.64
## Testing how your function handles a non-default sd_thresh
remove_outliers(diamonds, 
                price,
                x, 
                y, 
                z, 
                sd_thresh = 2)
       carat       cut color clarity depth table price     x     y     z
       <num>     <ord> <ord>   <ord> <num> <num> <int> <num> <num> <num>
    1:  0.23     Ideal     E     SI2  61.5    55   326  3.95  3.98  2.43
    2:  0.21   Premium     E     SI1  59.8    61   326  3.89  3.84  2.31
    3:  0.23      Good     E     VS1  56.9    65   327  4.05  4.07  2.31
    4:  0.29   Premium     I     VS2  62.4    58   334  4.20  4.23  2.63
    5:  0.31      Good     J     SI2  63.3    58   335  4.34  4.35  2.75
   ---                                                                  
50095:  0.72     Ideal     D     SI1  60.8    57  2757  5.75  5.76  3.50
50096:  0.72      Good     D     SI1  63.1    55  2757  5.69  5.75  3.61
50097:  0.70 Very Good     D     SI1  62.8    60  2757  5.66  5.68  3.56
50098:  0.86   Premium     H     SI2  61.0    58  2757  6.15  6.12  3.74
50099:  0.75     Ideal     D     SI2  62.2    55  2757  5.83  5.87  3.64

Goal 2: Write a function that imputes missing values for numeric variables in a dataset. The user should be able to supply the dataset, the variables to impute values for, and a function to use when imputing.

library(dplyr)

impute_missing_dt <- function(data, ..., impute_fun = mean) { 
  # Capture the bare column names passed through ...
  cols <- as.character(substitute(list(...)))[-1]
  
  # Find requested columns that are not numeric
  # ([[ works the same for data frames, tibbles, and data.tables)
  is_numeric <- vapply(cols, function(col) is.numeric(data[[col]]), logical(1))
  non_numeric <- cols[!is_numeric]
  
  if (length(non_numeric) > 0) {
    warning("The following columns are not numeric and will be skipped: ",
            paste(non_numeric, collapse = ", "))
  }
  
  # Keep only the numeric columns
  numeric_cols <- cols[is_numeric]
  
  data |> 
    mutate(across(
      .cols = all_of(numeric_cols), 
      # Replace NAs with impute_fun() applied to the observed values;
      # base replace() promotes integer columns to double when needed
      .fns = ~ replace(.x, is.na(.x), impute_fun(.x, na.rm = TRUE))
    ))
}
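The choice of impute_fun can matter a lot when a column has extreme values. This sketch applies the same replacement rule used inside impute_missing_dt() to a made-up vector. It uses a hypothetical helper, impute_with(), that exists only for this illustration.

```r
# Hypothetical toy vector with one missing value and one large value
x <- c(1, 2, NA, 100)

# Hypothetical helper: the same replacement rule as in impute_missing_dt()
impute_with <- function(v, impute_fun) {
  replace(v, is.na(v), impute_fun(v, na.rm = TRUE))
}

x_mean   <- impute_with(x, mean)    # NA becomes (1 + 2 + 100) / 3
x_median <- impute_with(x, median)  # NA becomes 2

# Integer input is promoted to double when the fill value is fractional
y_int <- impute_with(c(1L, NA, 4L), mean)
```

The mean is pulled toward the single large value, while the median is not. The integer case is why base replace() is used rather than tidyr::replace_na(), which refuses a fill value it cannot losslessly convert to the column's type.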

Testing My Function!

## Testing how your function handles multiple input variables
impute_missing_dt(nycflights13::flights,
                  arr_delay,
                  dep_delay)
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>
## Testing how your function handles an input that isn't numeric
impute_missing_dt(nycflights13::flights,
                  arr_delay,
                  carrier)
Warning in impute_missing_dt(nycflights13::flights, arr_delay, carrier): The
following columns are not numeric and will be skipped: carrier
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>
## Testing how your function handles a non-default impute_fun
impute_missing_dt(nycflights13::flights,
                  arr_delay,
                  dep_delay,
                  impute_fun = median)
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>

Primary Function

Goal 3: Write a fit_model() function that fits a specified linear regression model for a specified dataset. The function should:

  • allow the user to specify if outliers should be removed (TRUE or FALSE)
  • allow the user to specify if missing observations should be imputed (TRUE or FALSE)

If either option is TRUE, the function should call my remove_outliers() or impute_missing_dt() function to modify the data before the regression model is fit.

fit_model <- function(df, impute_missing = FALSE, remove_outliers = FALSE, mod_formula, ...) {
    # Impute missing values in the columns named in ... if requested
    if (impute_missing) {
        df <- impute_missing_dt(df, ...)
    }
    
    # Remove outliers in the columns named in ... if requested.
    # The logical argument remove_outliers does not mask the function:
    # when R looks up a name in a call, it skips objects that aren't functions.
    if (remove_outliers) {
        df <- remove_outliers(df, ...)
    }
    
    # Fit and return the linear regression model
    lm(mod_formula, data = df)
}
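fit_model() works because the bare column names in ... are forwarded untouched to the helpers, which read them with substitute(). Here is a small sketch of why that forwarding matters. It uses hypothetical functions (capture_names(), forward_dots(), forward_named()) that mimic the helpers' name-capturing line.

```r
# Hypothetical helper that captures bare column names, like the helpers above
capture_names <- function(...) {
  as.character(substitute(list(...)))[-1]
}

# Forwarding ... unchanged, as fit_model() does
forward_dots <- function(...) {
  capture_names(...)
}

# Passing the name through a named argument instead
forward_named <- function(col) {
  capture_names(col)
}

direct    <- capture_names(price, carat)  # "price" "carat"
forwarded <- forward_dots(price, carat)   # still "price" "carat"
renamed   <- forward_named(price)         # "col", not "price"
```

None of these calls ever evaluates price or carat, so the objects don't need to exist. The last call shows the fragility of this design: if another wrapper passes the columns through a named argument, the helpers see that argument's name instead of the column.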

Testing My Function!

fit_model(
  diamonds,
  mod_formula = price ~ carat + cut,
  remove_outliers = TRUE,
  impute_missing = TRUE,
  price, 
  carat
)

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2460.16      7526.96      1059.65      -410.54       295.80        82.62  

Iteration

In the diamonds dataset, I want to understand the relationship between price and size (carat). Specifically, I want to explore variation along three choices:

  1. The variables included in the model. I’ll explore 4 sets of variables:

    • No further variables (just price and carat)
    • Adjusting for cut
    • Adjusting for cut and clarity
    • Adjusting for cut, clarity, and color
  2. Whether or not to impute missing values in price and carat

  3. Whether or not to remove outliers in price and carat (I’ll define outliers as cases whose price or carat is more than 3 SDs away from the mean).

Parameters

First, I need to define the set of parameters to iterate fit_model() over. The tidyr package has a handy function called crossing() for generating argument combinations: for each argument, I specify all of its possible values, and crossing() generates every combination.

library(tidyr)

df_arg_combos <- crossing(
    impute = c(TRUE, FALSE),
    remove_outliers = c(TRUE, FALSE), 
    mod = list(y ~ x1, 
               y ~ x1 + x2)
)
df_arg_combos
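One detail of crossing() affects how the results are read later. It de-duplicates and sorts each input, so FALSE comes before TRUE, and the first column varies slowest. Here is a small sketch with just the two logical arguments.

```r
library(tidyr)

# Two logical choices crossed: 2 x 2 = 4 combinations
toy_combos <- crossing(
  impute = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE)
)
toy_combos
```

The first row is impute = FALSE, remove_outliers = FALSE, even though TRUE was listed first. With the full Goal 4 grid, the first rows (and so the first models returned by pmap()) should therefore be the ones fit with neither imputation nor outlier removal.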

Goal 4: Use crossing() to create the data frame of argument combinations for our analyses.

df_arg_combos <- crossing(
  impute_missing = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod_formula = list(
    price ~ carat,                         # No additional variables
    price ~ carat + cut,                   # Adjusting for cut
    price ~ carat + cut + clarity,         # Adjusting for cut and clarity
    price ~ carat + cut + clarity + color  # Adjusting for cut, clarity, and color
  )
)

Iterating Over the Parameters

I’ve arrived at the final step!

Goal 5: Use pmap() from purrr to apply the fit_model() function to the diamonds data for every combination of arguments in df_arg_combos.

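library(purrr)

# Apply fit_model() to every row of df_arg_combos. df = diamonds stays fixed,
# and the bare names price and carat are passed through ... to the helpers
results <- pmap(df_arg_combos, fit_model, df = diamonds, price, carat)

# Fitted models for each combination of arguments (diamonds has no missing
# values, so the imputed and non-imputed fits come out identical)
results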
[[1]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat  
      -2256         7756  


[[2]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2701.38      7871.08      1239.80      -528.60       367.91        74.59  


[[3]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3187.540     8472.026      713.804     -334.503      188.482        1.663  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4011.681    -1821.922      917.658     -430.047      257.141       26.909  
  clarity^7  
    186.742  


[[4]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3710.603     8886.129      698.907     -327.686      180.565       -1.207  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4217.535    -1832.406      923.273     -361.995      216.616        2.105  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    110.340    -1910.288     -627.954     -171.960       21.678      -85.943  
    color^6  
    -49.986  


[[5]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat  
      -2067         7411  


[[6]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2460.16      7526.96      1059.65      -410.54       295.80        82.62  


[[7]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2897.60      8118.13       618.65      -268.13       148.57         8.10  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
    3463.17     -1505.52       609.45      -286.81       158.48        71.83  
  clarity^7  
     183.33  


[[8]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3376.856     8513.377      607.668     -263.835      142.453        5.506  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3654.785    -1503.171      612.049     -225.568      124.920       47.356  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    112.889    -1714.316     -539.239     -127.282       43.649      -60.866  
    color^6  
    -49.333  


[[9]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat  
      -2256         7756  


[[10]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2701.38      7871.08      1239.80      -528.60       367.91        74.59  


[[11]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3187.540     8472.026      713.804     -334.503      188.482        1.663  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4011.681    -1821.922      917.658     -430.047      257.141       26.909  
  clarity^7  
    186.742  


[[12]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3710.603     8886.129      698.907     -327.686      180.565       -1.207  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4217.535    -1832.406      923.273     -361.995      216.616        2.105  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    110.340    -1910.288     -627.954     -171.960       21.678      -85.943  
    color^6  
    -49.986  


[[13]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat  
      -2067         7411  


[[14]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2460.16      7526.96      1059.65      -410.54       295.80        82.62  


[[15]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2897.60      8118.13       618.65      -268.13       148.57         8.10  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
    3463.17     -1505.52       609.45      -286.81       158.48        71.83  
  clarity^7  
     183.33  


[[16]]

Call:
lm(formula = mod_formula, data = df)

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3376.856     8513.377      607.668     -263.835      142.453        5.506  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3654.785    -1503.171      612.049     -225.568      124.920       47.356  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    112.889    -1714.316     -539.239     -127.282       43.649      -60.866  
    color^6  
    -49.333
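Because results is a plain unnamed list, matching a model like [[7]] back to its arguments means counting rows of df_arg_combos. A pattern I find handy is to store each model as a list column next to the arguments that produced it, then pull out the coefficient of interest.

The sketch below makes two substitutions so it runs on its own:

  • a hypothetical fit_toy() stands in for fit_model()
  • mtcars stands in for diamonds

```r
library(dplyr)
library(purrr)
library(tidyr)

# Hypothetical stand-in for fit_model(): optionally standardize wt, then fit
fit_toy <- function(formula, scale_wt = FALSE) {
  df <- mtcars
  if (scale_wt) df$wt <- as.vector(scale(df$wt))
  lm(formula, data = df)
}

toy_grid <- crossing(
  scale_wt = c(TRUE, FALSE),
  formula = list(mpg ~ wt, mpg ~ wt + hp)
)

# Store each fitted model next to the arguments that produced it,
# then extract the wt coefficient as a plain numeric column
toy_results <- toy_grid |>
  mutate(
    model = pmap(toy_grid, fit_toy),
    wt_coef = map_dbl(model, \(m) coef(m)[["wt"]])
  )
toy_results
```

With the real objects, the equivalent step should be something like df_arg_combos |> mutate(model = results). This works because pmap() returns one model per row, in row order.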