Grouped Lasso regression and classification for tidyfit
Source: R/fit.group_lasso.R
dot-fit.group_lasso.Rd
Fits a linear regression or classification with a grouped L1 penalty on a 'tidyFit' R6 class. The function can be used with regress and classify.
Details
Hyperparameters:
lambda (L1 penalty)
Important method arguments (passed to m)
The Group Lasso regression is estimated using gglasso::gglasso. The 'group' argument is a named vector (names are the features, values are the group indices) passed directly to m() (see examples). See ?gglasso for more details. Only binomial classification is possible. Weights are ignored for classification.
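The named 'group' vector can be built and checked in base R before it is handed to m(). The sketch below is only illustrative: the predictor names are copied from the example on this page, and the checks (every predictor has a group, group ids are consecutive integers starting at 1, which is how ?gglasso describes its group argument) are suggested sanity checks, not something tidyfit is known to run for you.

```r
# Predictor names as used in the example on this page
predictors <- c("Mkt-RF", "SMB", "HML", "RMW", "CMA", "RF")

# Named vector: names are predictors, values are integer group ids
groups <- setNames(c(1, 2, 2, 3, 3, 1), predictors)

# Suggested checks (assumption: not necessarily performed by tidyfit itself)
# 1. every predictor is assigned to a group
missing_terms <- setdiff(predictors, names(groups))
# 2. group ids are the consecutive integers 1, 2, ..., max(groups)
ids <- sort(unique(unname(groups)))
ids_ok <- identical(ids, as.numeric(seq_len(max(groups))))

stopifnot(length(missing_terms) == 0, ids_ok)

# Predictors listed together here share a group and are penalized jointly
members <- split(names(groups), groups)
members
```

The resulting vector is what would be passed as group = groups in the m() call.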
Implementation
Features are standardized by default with coefficients transformed to the original scale.
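To make the back-transformation concrete, here is a minimal base R sketch on simulated data. Plain least squares stands in for the penalized fit (an assumption made so the result can be checked exactly), and the sample standard deviation is used for scaling; the scaling convention used internally by tidyfit or gglasso may differ.

```r
set.seed(1)
n <- 200
x <- cbind(a = rnorm(n, mean = 5, sd = 10),
           b = rnorm(n, mean = -2, sd = 0.1))
y <- 1 + 0.3 * x[, "a"] - 4 * x[, "b"] + rnorm(n)

# Standardize the features: subtract column means, divide by column sds
mu   <- colMeans(x)
sdev <- apply(x, 2, sd)
x_std <- scale(x, center = mu, scale = sdev)

# Fit on the standardized scale (OLS used here in place of the group lasso)
beta_std <- coef(lm(y ~ x_std))

# Transform back: slopes are divided by the sds, and the intercept absorbs
# the shift introduced by centering
beta      <- beta_std[-1] / sdev
intercept <- beta_std[1] - sum(beta * mu)
coef_orig <- c(intercept, beta)

# For OLS the back-transformed coefficients match a fit on the raw features
coef_raw <- coef(lm(y ~ x))
all.equal(unname(coef_orig), unname(coef_raw))
```

With a penalized estimator the two fits would no longer coincide, because the penalty is applied on the standardized scale; the back-transformation formula itself is the same.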
If no hyperparameter grid is passed (is.null(control$lambda)), dials::grid_regular() is used to determine a sensible default grid. The grid size is 100. Note that the grid selection tools provided by gglasso::gglasso cannot be used (e.g. dfmax). This is to guarantee identical grids across groups in the tibble.
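A default grid of this kind could be generated as sketched below. The use of dials::penalty() with its default range is an assumption made for illustration; the parameter object and range that tidyfit actually passes to dials::grid_regular() are not stated on this page.

```r
library(dials)

# Regular grid with 100 levels over the default range of dials::penalty()
# (assumption: tidyfit may use a different parameter object or range)
grid <- grid_regular(penalty(), levels = 100)
lambda_grid <- grid$penalty

length(lambda_grid)
range(lambda_grid)
```

Because the grid depends only on the parameter definition and not on the data, every group in a grouped tibble receives the same lambda values, which a data-driven path such as the one computed by gglasso::gglasso would not guarantee.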
References
Yang Y, Zou H, Bhatnagar S (2020). gglasso: Group Lasso Penalized Learning Using a Unified BMD Algorithm. R package version 1.5, https://CRAN.R-project.org/package=gglasso.
See also
.fit.lasso, .fit.blasso, .fit.adalasso and m methods
Examples
# Load data
data <- tidyfit::Factor_Industry_Returns
groups <- setNames(c(1, 2, 2, 3, 3, 1), c("Mkt-RF", "SMB", "HML", "RMW", "CMA", "RF"))
# Stand-alone function
fit <- m("group_lasso", Return ~ ., data, lambda = 0.5, group = groups)
fit
#> # A tibble: 1 × 6
#>   estimator_fct    `size (MB)` grid_id  model_object settings errors
#>   <chr>                  <dbl> <chr>    <list>       <list>   <chr>
#> 1 gglasso::gglasso           0 #001|001 <tidyFit>    <tibble> NA/NaN argu…
# Within 'regress' function
fit <- regress(data, Return ~ ., m("group_lasso", lambda = c(0.1, 0.5), group = groups),
               .mask = c("Date", "Industry"))
coef(fit)
#> # A tibble: 14 × 5
#> # Groups:   model [1]
#>    model       term        estimate grid_id  model_info
#>    <chr>       <chr>          <dbl> <chr>    <list>
#>  1 group_lasso (Intercept)   0.224  #001|001 <tibble [1 × 1]>
#>  2 group_lasso CMA           0.0416 #001|001 <tibble [1 × 1]>
#>  3 group_lasso HML           0.0554 #001|001 <tibble [1 × 1]>
#>  4 group_lasso Mkt-RF        0.922  #001|001 <tibble [1 × 1]>
#>  5 group_lasso RF            0.505  #001|001 <tibble [1 × 1]>
#>  6 group_lasso RMW           0.0899 #001|001 <tibble [1 × 1]>
#>  7 group_lasso SMB           0.0147 #001|001 <tibble [1 × 1]>
#>  8 group_lasso (Intercept)   0.543  #002|001 <tibble [1 × 1]>
#>  9 group_lasso CMA           0      #002|001 <tibble [1 × 1]>
#> 10 group_lasso HML           0      #002|001 <tibble [1 × 1]>
#> 11 group_lasso Mkt-RF        0.764  #002|001 <tibble [1 × 1]>
#> 12 group_lasso RF            0      #002|001 <tibble [1 × 1]>
#> 13 group_lasso RMW           0      #002|001 <tibble [1 × 1]>
#> 14 group_lasso SMB           0.0431 #002|001 <tibble [1 × 1]>