pipelinex.extras.ops.ignite.declaratives package
Submodules
pipelinex.extras.ops.ignite.declaratives.declarative_trainer module
class pipelinex.extras.ops.ignite.declaratives.declarative_trainer.NetworkTrain(loss_fn=None, epochs=None, seed=None, optimizer=None, optimizer_params={}, train_data_loader_params={}, val_data_loader_params={}, evaluation_metrics=None, evaluate_train_data=None, evaluate_val_data=None, progress_update=None, scheduler=None, scheduler_params={}, model_checkpoint=None, model_checkpoint_params={}, early_stopping_params={}, time_limit=None, train_dataset_size_limit=None, val_dataset_size_limit=None, cudnn_deterministic=None, cudnn_benchmark=None, mlflow_logging=True, train_params={})
Bases: object
Create a trainer for a supervised PyTorch model.
Parameters:
loss_fn (callable) – Loss function used for training. Accepts an instance of one of the loss functions listed at https://pytorch.org/docs/stable/nn.html#loss-functions
epochs (int, optional) – Maximum number of epochs to train.
seed (int, optional) – Random seed for training.
optimizer (torch.optim, optional) – Optimizer used to train. Accepts optimizers at https://pytorch.org/docs/stable/optim.html
optimizer_params (dict, optional) – Parameters for the optimizer.
train_data_loader_params (dict, optional) – Parameters for the training data loader. Accepts the arguments listed at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
val_data_loader_params (dict, optional) – Parameters for the validation data loader. Accepts the arguments listed at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
evaluation_metrics (dict, optional) – Metrics to compute for evaluation. Accepts a dict of the metrics listed at https://pytorch.org/ignite/metrics.html
evaluate_train_data (str, optional) – When to compute evaluation_metrics on the training dataset. Accepts the events listed at https://pytorch.org/ignite/engine.html#ignite.engine.Events
evaluate_val_data (str, optional) – When to compute evaluation_metrics on the validation dataset. Accepts the events listed at https://pytorch.org/ignite/engine.html#ignite.engine.Events
progress_update (bool, optional) – Whether to show a progress bar using the tqdm package.
scheduler (ignite.contrib.handlers.param_scheduler.ParamScheduler, optional) – Parameter scheduler. Accepts a ParamScheduler listed at https://pytorch.org/ignite/contrib/handlers.html#module-ignite.contrib.handlers.param_scheduler
scheduler_params (dict, optional) – Parameters for the scheduler.
model_checkpoint (ignite.handlers.ModelCheckpoint, optional) – Model checkpoint handler. Accepts a ModelCheckpoint; see https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint
model_checkpoint_params (dict, optional) – Parameters for ModelCheckpoint; see https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint
early_stopping_params (dict, optional) – Parameters for EarlyStopping; see https://pytorch.org/ignite/handlers.html#ignite.handlers.EarlyStopping
time_limit (int, optional) – Time limit for training, in seconds.
train_dataset_size_limit (int, optional) – If specified, only a subset of the training dataset, limited to this many samples, is used. Useful for a quick preliminary check before training on the whole dataset.
val_dataset_size_limit (int, optional) – If specified, only a subset of the validation dataset, limited to this many samples, is used. Useful for a quick preliminary check before validating on the whole dataset.
cudnn_deterministic (bool, optional) – Value for torch.backends.cudnn.deterministic. See https://pytorch.org/docs/stable/notes/randomness.html for details.
cudnn_benchmark (bool, optional) – Value for torch.backends.cudnn.benchmark. See https://pytorch.org/docs/stable/notes/randomness.html for details.
mlflow_logging (bool, optional) – If True and MLflow is installed, MLflow logging is enabled.
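The train_dataset_size_limit and val_dataset_size_limit options above restrict training or validation to part of a dataset. The following is a minimal sketch of one way to do this, assuming the subset is simply the first N samples of any object that supports `__getitem__` and `__len__` (the map-style dataset protocol that PyTorch data loaders consume). `LimitedDataset` is a hypothetical name used for illustration; it is not part of the documented pipelinex API, and NetworkTrain may select its subset differently.

```python
class LimitedDataset:
    """Hypothetical wrapper exposing only the first `size` samples of a dataset.

    Assumption: the size limit keeps the leading samples; this is an
    illustration, not the NetworkTrain implementation.
    """

    def __init__(self, dataset, size):
        size = int(size)
        if size < 0:
            raise ValueError("size must be non-negative")
        self.dataset = dataset
        # Never report more samples than the underlying dataset has.
        self.size = min(size, len(dataset))

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # Reject indices outside the limited range, so iteration stops at `size`.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return self.dataset[index]


full = list(range(1000))  # stand-in for a map-style dataset
quick_check = LimitedDataset(full, 10)
print(len(quick_check))       # 10
print(list(quick_check)[-1])  # 9
```

Wrapping the dataset rather than copying it keeps memory use unchanged and leaves the original dataset untouched.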
Returns:
A callable that trains a PyTorch model.
Return type:
trainer (callable)
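The epochs, time_limit, and early_stopping_params options together determine when training ends. The framework-free sketch below shows how such conditions can interact, assuming early stopping follows the common patience rule (stop after `patience` consecutive evaluations without improving on the best score by more than `min_delta`, with higher scores treated as better), which is how Ignite's EarlyStopping handler behaves. `run_training`, `train_one_epoch`, and `evaluate` are hypothetical stand-ins; this is not the NetworkTrain implementation.

```python
import time


def run_training(train_one_epoch, evaluate, epochs, time_limit=None,
                 patience=None, min_delta=0.0):
    """Hypothetical loop: return (epochs_run, reason) once a stop condition is met."""
    start = time.monotonic()
    best_score = None
    bad_epochs = 0  # consecutive evaluations without sufficient improvement
    for epoch in range(1, epochs + 1):
        train_one_epoch(epoch)
        score = evaluate(epoch)
        if patience is not None:
            if best_score is None or score > best_score + min_delta:
                best_score = score
                bad_epochs = 0
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    return epoch, "early_stopping"
        # Check the wall-clock budget (in seconds) after each completed epoch.
        if time_limit is not None and time.monotonic() - start >= time_limit:
            return epoch, "time_limit"
    return epochs, "max_epochs"


# Validation scores that peak at epoch 3 and then stall.
scores = [0.60, 0.70, 0.75, 0.74, 0.75, 0.73, 0.80]
result = run_training(
    train_one_epoch=lambda epoch: None,
    evaluate=lambda epoch: scores[epoch - 1],
    epochs=len(scores),
    patience=2,
)
print(result)  # (5, 'early_stopping')
```

With `patience=2`, epochs 4 and 5 fail to beat the best score of 0.75, so training stops at epoch 5 even though a later epoch would have improved; choosing patience is a trade-off between compute and the risk of stopping too early.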
__init__(loss_fn=None, epochs=None, seed=None, optimizer=None, optimizer_params={}, train_data_loader_params={}, val_data_loader_params={}, evaluation_metrics=None, evaluate_train_data=None, evaluate_val_data=None, progress_update=None, scheduler=None, scheduler_params={}, model_checkpoint=None, model_checkpoint_params={}, early_stopping_params={}, time_limit=None, train_dataset_size_limit=None, val_dataset_size_limit=None, cudnn_deterministic=None, cudnn_benchmark=None, mlflow_logging=True, train_params={})
Initialize self. See help(type(self)) for accurate signature.
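Because every NetworkTrain option is a keyword argument, a trainer can be described declaratively as a plain mapping, for example one read from a configuration file, and then passed as NetworkTrain(**config). The sketch below checks such a mapping against the parameter names in the signature above before construction, so a misspelled key fails early. `check_trainer_config` is a hypothetical helper and the config values are illustrative assumptions; the sketch uses only the standard library and does not import pipelinex.

```python
# Parameter names copied from the NetworkTrain signature documented above.
NETWORK_TRAIN_PARAMS = frozenset({
    "loss_fn", "epochs", "seed", "optimizer", "optimizer_params",
    "train_data_loader_params", "val_data_loader_params",
    "evaluation_metrics", "evaluate_train_data", "evaluate_val_data",
    "progress_update", "scheduler", "scheduler_params",
    "model_checkpoint", "model_checkpoint_params", "early_stopping_params",
    "time_limit", "train_dataset_size_limit", "val_dataset_size_limit",
    "cudnn_deterministic", "cudnn_benchmark", "mlflow_logging", "train_params",
})


def check_trainer_config(config):
    """Hypothetical helper: raise TypeError if `config` has unknown keys."""
    unknown = sorted(set(config) - NETWORK_TRAIN_PARAMS)
    if unknown:
        raise TypeError(f"unexpected NetworkTrain arguments: {unknown}")
    return dict(config)


# Illustrative values only; loss_fn, optimizer, and similar options would be
# actual PyTorch objects in practice and are omitted here.
config = {
    "epochs": 10,
    "seed": 42,
    "optimizer_params": {"lr": 0.001},
    "train_data_loader_params": {"batch_size": 32, "shuffle": True},
    "evaluate_val_data": "EPOCH_COMPLETED",
    "train_dataset_size_limit": 256,
}
kwargs = check_trainer_config(config)
# With pipelinex installed, the trainer would then be built as:
#     trainer = NetworkTrain(**kwargs)
```

Passing an unexpected keyword directly to NetworkTrain would also raise a TypeError; the explicit check simply reports every unknown key at once.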