Use the EQRN_fit_restart() wrapper instead, with data_type="iid", for better stability through fitting restarts.
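Fitting restarts can be pictured as running the same fit several times from different random initializations and keeping the best run. The base-R sketch below illustrates only that general idea. It assumes that "best" means the lowest final loss. fit_with_restarts() and toy_fit() are hypothetical helpers, not the implementation of EQRN_fit_restart(), whose selection rule and arguments may differ.

```r
# Hypothetical helper (not part of EQRN): run a fitting function once per
# seed and keep the run with the smallest finite final loss.
fit_with_restarts <- function(fit_fun, seeds = 1:3) {
  best <- NULL
  for (s in seeds) {
    fit <- fit_fun(seed = s)
    # Diverged runs (non-finite loss) are skipped
    if (is.finite(fit$loss) && (is.null(best) || fit$loss < best$loss)) {
      best <- fit
    }
  }
  if (is.null(best)) stop("all restarts failed")
  best
}

# Toy stand-in for one training run (hypothetical): minimise (theta - 2)^2
# with optim(), starting from a seed-dependent random point.
toy_fit <- function(seed) {
  set.seed(seed)
  start <- runif(1, -10, 10)
  res <- optim(start, function(theta) (theta - 2)^2, method = "BFGS")
  list(seed = seed, theta = res$par, loss = res$value)
}

best <- fit_with_restarts(toy_fit, seeds = 1:5)
best$theta
```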
Usage
EQRN_fit(
  X,
  y,
  intermediate_quantiles,
  interm_lvl,
  shape_fixed = FALSE,
  net_structure = c(5, 3, 3),
  hidden_fct = torch::nnf_sigmoid,
  p_drop = 0,
  intermediate_q_feature = TRUE,
  learning_rate = 1e-04,
  L2_pen = 0,
  shape_penalty = 0,
  scale_features = TRUE,
  n_epochs = 500,
  batch_size = 256,
  X_valid = NULL,
  y_valid = NULL,
  quant_valid = NULL,
  lr_decay = 1,
  patience_decay = n_epochs,
  min_lr = 0,
  patience_stop = n_epochs,
  tol = 1e-06,
  orthogonal_gpd = TRUE,
  patience_lag = 1,
  optim_met = "adam",
  seed = NULL,
  verbose = 2,
  device = default_device()
)
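The sketch below shows what the default net_structure = c(5, 3, 3) means for model size. It counts the trainable parameters of a plain fully connected network with biases. It assumes the following:
- the network has one output per estimated GPD parameter (scale and shape, or scale only when shape_fixed = TRUE);
- intermediate_q_feature = TRUE adds exactly one input column.

It does not cover the "SNN"/"SSNN" variants or a custom torch::nn_module. count_mlp_params() is a hypothetical helper, not an EQRN function.

```r
# Hypothetical helper (not part of EQRN): trainable parameters of a fully
# connected network with biases, with hidden layer sizes given by net_structure.
count_mlp_params <- function(n_inputs, net_structure, n_outputs) {
  sizes <- c(n_inputs, net_structure, n_outputs)
  fan_in <- sizes[-length(sizes)]
  fan_out <- sizes[-1]
  # Each layer has a fan_in x fan_out weight matrix plus fan_out biases
  sum(fan_in * fan_out + fan_out)
}

p <- 4                           # number of columns of X (example value)
intermediate_q_feature <- TRUE   # appends the intermediate quantiles as one extra column
shape_fixed <- FALSE             # assumed: the network then outputs both scale and shape

n_in <- p + as.integer(intermediate_q_feature)
n_out <- if (shape_fixed) 1L else 2L
count_mlp_params(n_in, net_structure = c(5, 3, 3), n_out)  # returns 68
```

The vector's length (3) sets the number of hidden layers, and its entries (5, 3 and 3 neurons) set their widths.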
Arguments
- X
Matrix of covariates, for training.
- y
Response variable vector to model the extreme conditional quantile of, for training.
- intermediate_quantiles
Vector of intermediate conditional quantiles at level interm_lvl.
- interm_lvl
Probability level for the intermediate quantiles intermediate_quantiles.
- shape_fixed
Whether the shape estimate depends on the covariates or not (bool).
- net_structure
Vector of integers whose length determines the number of layers in the neural network and whose entries give the number of neurons in each corresponding successive layer. If hidden_fct=="SSNN", should instead be a named list with "scale" and "shape" vectors for the two respective sub-networks. Can also be a torch::nn_module network with correct input and output dimensions, which overrides the hidden_fct, shape_fixed and p_drop arguments.
- hidden_fct
Activation function for the hidden layers. Can be either a callable function (preferably from the torch library) or one of the strings "SNN" or "SSNN" for self-normalizing networks (with common or separated networks for the scale and shape estimates, respectively). In the latter cases, shape_fixed has no effect.
- p_drop
Probability parameter for dropout before each hidden layer, for regularization during training. Alpha-dropout is used with SNNs.
- intermediate_q_feature
Whether to use the intermediate_quantiles as an additional covariate, by appending it to the X matrix (bool).
- learning_rate
Initial learning rate for the optimizer during training of the neural network.
- L2_pen
L2 weight penalty parameter for regularization during training.
- shape_penalty
Penalty parameter for the shape estimate, to potentially regularize its variation from the fixed prior estimate.
- scale_features
Whether to rescale each input covariate to zero mean and unit variance before applying the network (recommended).
- n_epochs
Number of training epochs.
- batch_size
Batch size used during training.
- X_valid
Covariates in a validation set, or NULL. Used for monitoring validation loss during training, enabling learning-rate decay and early stopping.
- y_valid
Response variable in a validation set, or NULL. Used for monitoring validation loss during training, enabling learning-rate decay and early stopping.
- quant_valid
Intermediate conditional quantiles at level interm_lvl in a validation set, or NULL. Used for monitoring validation loss during training, enabling learning-rate decay and early stopping.
- lr_decay
Learning rate decay factor.
- patience_decay
Number of epochs of non-improving validation loss before a learning-rate decay is performed.
- min_lr
Minimum learning rate, under which no more decay is performed.
- patience_stop
Number of epochs of non-improving validation loss before early stopping is performed.
- tol
Tolerance for stopping training, in case of no significant training loss improvements.
- orthogonal_gpd
Whether to use the orthogonal reparametrization of the estimated GPD parameters (recommended).
- patience_lag
The validation loss is considered to be non-improving if it is larger than on any of the previous patience_lag epochs.
- optim_met
DEPRECATED. Optimization algorithm to use during training. "adam" is the default.
- seed
Integer random seed for reproducibility in network weight initialization.
- verbose
Amount of information printed during training (0: nothing, 1: most important, 2: everything).
- device
(optional) A torch::torch_device(). Defaults to default_device().
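The orthogonal_gpd argument refers to a reparametrization of the generalized Pareto distribution (GPD). This sketch assumes the standard choice nu = sigma * (1 + xi), which makes nu and the shape xi orthogonal in the Fisher-information sense. This section does not state whether EQRN uses exactly this form. The sketch computes the GPD negative log-likelihood in both parametrizations. The reparametrization changes the coordinates in which the parameters are estimated, not the fitted distribution. All function names are hypothetical.

```r
# Hypothetical helpers (not EQRN functions), assuming the orthogonal
# parametrization nu = sigma * (1 + xi), valid for xi > -1.
to_orthogonal   <- function(sigma, xi) list(nu = sigma * (1 + xi), xi = xi)
from_orthogonal <- function(nu, xi) list(sigma = nu / (1 + xi), xi = xi)

# GPD negative log-likelihood of exceedances z >= 0 with scale sigma and
# shape xi (scalars or vectors of the same length as z).
gpd_nll <- function(z, sigma, xi) {
  n <- length(z)
  sigma <- rep_len(sigma, n)
  xi <- rep_len(xi, n)
  u <- 1 + xi * z / sigma
  # Outside the parameter space or the support of the distribution
  if (any(sigma <= 0) || any(u <= 0)) return(Inf)
  small <- abs(xi) < 1e-8
  terms <- numeric(n)
  terms[small] <- z[small] / sigma[small]                 # exponential limit xi -> 0
  terms[!small] <- (1 + 1 / xi[!small]) * log(u[!small])
  sum(log(sigma) + terms)
}

# Same likelihood, written in terms of the orthogonal parameters (nu, xi)
gpd_nll_orth <- function(z, nu, xi) {
  p <- from_orthogonal(nu, xi)
  gpd_nll(z, p$sigma, p$xi)
}

z <- c(0.2, 1.5, 0.7, 3.1)  # example exceedances (made-up values)
orth <- to_orthogonal(sigma = 1.2, xi = 0.3)
c(standard = gpd_nll(z, 1.2, 0.3),
  orthogonal = gpd_nll_orth(z, orth$nu, orth$xi))
```

Both entries agree up to floating-point rounding.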
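The validation-monitoring arguments lr_decay, patience_decay, min_lr, patience_stop and patience_lag can be illustrated by replaying a given sequence of validation losses. The base-R sketch below is one reading of the argument descriptions, not EQRN's training loop. It assumes the following:
- the patience counters count consecutive non-improving epochs and reset on an improving epoch;
- the decay counter also resets after each decay;
- "larger than on any of the previous patience_lag epochs" means larger than all of them;
- no further decay happens once the learning rate is below min_lr.

simulate_schedule() is hypothetical.

```r
# Hypothetical replay (not EQRN's code) of learning-rate decay and early
# stopping driven by a vector of per-epoch validation losses.
simulate_schedule <- function(valid_loss, learning_rate = 1e-4, lr_decay = 1,
                              patience_decay = length(valid_loss), min_lr = 0,
                              patience_stop = length(valid_loss),
                              patience_lag = 1) {
  lr <- learning_rate
  lrs <- numeric(length(valid_loss))
  n_decay <- 0  # non-improving epochs since the last improvement or decay
  n_stop <- 0   # non-improving epochs since the last improvement
  for (e in seq_along(valid_loss)) {
    lrs[e] <- lr  # learning rate used during epoch e
    non_improving <- FALSE
    if (e > 1) {
      prev <- valid_loss[max(1, e - patience_lag):(e - 1)]
      # Assumed reading: larger than the loss on every one of the previous
      # patience_lag epochs
      non_improving <- valid_loss[e] > max(prev)
    }
    if (non_improving) {
      n_decay <- n_decay + 1
      n_stop <- n_stop + 1
    } else {
      n_decay <- 0
      n_stop <- 0
    }
    if (n_decay >= patience_decay && lr >= min_lr) {
      lr <- lr * lr_decay  # applies from the next epoch on
      n_decay <- 0
    }
    if (n_stop >= patience_stop) {
      return(list(lr = lrs[seq_len(e)], stopped_at = e))
    }
  }
  list(lr = lrs, stopped_at = NA_integer_)
}

res <- simulate_schedule(c(1, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99),
                         learning_rate = 1e-3, lr_decay = 0.5,
                         patience_decay = 2, patience_stop = 4)
res$stopped_at
res$lr
```

Under this reading, increasing patience_lag makes an epoch less likely to count as non-improving, so decay and early stopping trigger later.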