Fits an ElasticNet regression or classification on a 'tidyFit' R6 class. The function can be used with regress and classify.
Details
Hyperparameters:
lambda (penalty)
alpha (L1-L2 mixing parameter)
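For reference, the glmnet documentation (Friedman, Hastie and Tibshirani, 2010) gives the Gaussian (regression) objective below. lambda scales the total penalty, and alpha mixes the L1 (lasso) and squared L2 (ridge) terms: alpha = 1 gives the lasso and alpha = 0 gives ridge regression. For classification, the squared-error term is replaced by the negative log-likelihood of the chosen family.

```latex
\min_{\beta_0,\,\beta}\;
\frac{1}{2N}\sum_{i=1}^{N}\bigl(y_i - \beta_0 - x_i^\top \beta\bigr)^2
\;+\; \lambda\Bigl[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Bigr]
```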
Important method arguments (passed to m)
The ElasticNet regression is estimated using glmnet::glmnet. See ?glmnet for more details. For classification, pass family = "binomial" through the ... argument of m, or use classify.
Implementation
If the response variable contains more than 2 classes, a multinomial response is used automatically.
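The class-count rule above can be sketched in base R as follows. This is a minimal illustration, not tidyfit's internal code. It assumes the number of classes is taken from the distinct values of the response, and the helper name pick_glmnet_family is hypothetical.

```r
# Hypothetical helper (not part of tidyfit): chooses the glmnet family for a
# classification response by counting its distinct classes.
pick_glmnet_family <- function(y) {
  n_classes <- length(unique(y))
  if (n_classes > 2) "multinomial" else "binomial"
}

pick_glmnet_family(c("a", "b", "a"))          # "binomial"
pick_glmnet_family(factor(c("a", "b", "c")))  # "multinomial"
```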
An intercept is always included. Features are standardized before fitting, and the reported coefficients are transformed back to the original scale of the features.
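The back-transformation can be illustrated with a short sketch. Plain least squares (lm) stands in for glmnet here, because without a penalty the standardized and raw-scale fits must agree, which makes the mapping easy to check. With a penalty, the mapping formula is the same, but fitting on standardized features is not equivalent to penalizing the raw features. Using sd() as the scale factor is an assumption of this sketch.

```r
# Sketch: fit on standardized features, then map coefficients back to the
# original scale. lm() is a stand-in for glmnet (no penalty).
set.seed(1)
n <- 200
X <- cbind(x1 = rnorm(n, mean = 5, sd = 2), x2 = rnorm(n, mean = -1, sd = 0.5))
y <- 1 + 0.8 * X[, "x1"] - 2 * X[, "x2"] + rnorm(n)

mu <- colMeans(X)
s  <- apply(X, 2, sd)                   # assumed scale factor
Z  <- scale(X, center = mu, scale = s)  # standardized features

fit_std <- lm(y ~ Z)
b0_std  <- coef(fit_std)[1]
b_std   <- coef(fit_std)[-1]

# Back-transform: beta_j = b_j / s_j, intercept = b0 - sum(beta_j * mu_j)
beta_orig <- b_std / s
b0_orig   <- b0_std - sum(beta_orig * mu)

# Compare with a fit on the raw features
fit_raw <- lm(y ~ X)
all.equal(unname(c(b0_orig, beta_orig)), unname(coef(fit_raw)))  # TRUE
```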
If no hyperparameter grid is passed (is.null(control$lambda) and is.null(control$alpha)), dials::grid_regular() is used to determine a sensible default grid. The grid size is 100 for lambda and 5 for alpha. Note that the grid selection tools provided by glmnet::glmnet (e.g. dfmax) cannot be used. This guarantees identical grids across groups in the tibble.
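A base-R sketch of what such a regular default grid looks like is shown below. The ranges are assumptions modelled on the dials::penalty() and dials::mixture() defaults (log10 lambda in [-10, 0], alpha in [0, 1]). Crossing the two vectors into candidate pairs is also an assumption of this sketch. The key point is that the vectors are fixed in advance, not derived from each group's data, so every group sees the same values.

```r
# Sketch of a regular default grid: 100 penalty values, 5 mixing values.
# Ranges are assumptions modelled on dials::penalty() / dials::mixture().
lambda_grid <- 10^seq(-10, 0, length.out = 100)  # log-spaced penalties
alpha_grid  <- seq(0, 1, length.out = 5)         # evenly spaced mixing values

# Fixed vectors reused for every group -> identical grids across groups.
grid <- expand.grid(alpha = alpha_grid, lambda = lambda_grid)
nrow(grid)  # 500
alpha_grid  # 0.00 0.25 0.50 0.75 1.00
```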
References
Jerome Friedman, Trevor Hastie, Robert Tibshirani (2010). Regularization Paths for Generalized Linear Models via Coordinate Descent. Journal of Statistical Software, 33(1), 1-22. URL https://www.jstatsoft.org/v33/i01/.
See also
.fit.lasso, .fit.adalasso, .fit.ridge and m methods
Examples
# Load data
data <- tidyfit::Factor_Industry_Returns
# Stand-alone function
fit <- m("enet", Return ~ ., data, lambda = c(0, 0.1), alpha = 0.5)
fit
#> # A tibble: 2 × 5
#> estimator_fct `size (MB)` grid_id model_object settings
#> <chr> <dbl> <chr> <list> <list>
#> 1 glmnet::glmnet 3.06 #001|001 <tidyFit> <tibble [1 × 3]>
#> 2 glmnet::glmnet 3.06 #001|002 <tidyFit> <tibble [1 × 3]>
# Within 'regress' function
fit <- regress(data, Return ~ ., m("enet", alpha = c(0, 0.5), lambda = c(0.1)),
.mask = c("Date", "Industry"), .cv = "vfold_cv")
coef(fit)
#> # A tibble: 7 × 4
#> # Groups: model [1]
#> model term estimate model_info
#> <chr> <chr> <dbl> <list>
#> 1 enet (Intercept) 0.00557 <tibble [1 × 2]>
#> 2 enet Mkt-RF 0.953 <tibble [1 × 2]>
#> 3 enet SMB 0.0232 <tibble [1 × 2]>
#> 4 enet HML 0.0641 <tibble [1 × 2]>
#> 5 enet RMW 0.153 <tibble [1 × 2]>
#> 6 enet CMA 0.0926 <tibble [1 × 2]>
#> 7 enet RF 0.959 <tibble [1 × 2]>