import copy
import logging
from pathlib import Path
import ignite
import numpy as np
import torch
from ignite.contrib.handlers.mlflow_logger import (
MLflowLogger,
OutputHandler,
global_step_from_engine,
)
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint
from ignite.metrics import RunningAverage
from torch.utils.data import DataLoader
from ..handlers.time_limit import TimeLimit
# Module-level logger; handlers/levels are configured by the host application.
log = logging.getLogger(__name__)
# Public API of this module.
__all__ = ["NetworkTrain"]
class NetworkTrain:
    """Create a trainer for a supervised PyTorch model.

    Args:
        loss_fn (callable): Loss function used to train.
            Accepts an instance of loss functions at https://pytorch.org/docs/stable/nn.html#loss-functions
        epochs (int, optional): Max epochs to train
        seed (int, optional): Random seed for training.
        optimizer (torch.optim, optional): Optimizer used to train.
            Accepts optimizers at https://pytorch.org/docs/stable/optim.html
        optimizer_params (dict, optional): Parameters for optimizer.
        train_data_loader_params (dict, optional): Parameters for data loader for training.
            Accepts args at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        val_data_loader_params (dict, optional): Parameters for data loader for validation.
            Accepts args at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        evaluation_metrics (dict, optional): Metrics to compute for evaluation.
            Accepts dict of metrics at https://pytorch.org/ignite/metrics.html
        evaluate_train_data (str, optional): When to compute evaluation_metrics using training dataset.
            Accepts events at https://pytorch.org/ignite/engine.html#ignite.engine.Events
        evaluate_val_data (str, optional): When to compute evaluation_metrics using validation dataset.
            Accepts events at https://pytorch.org/ignite/engine.html#ignite.engine.Events
        progress_update (bool, optional): Whether to show progress bar using tqdm package
        scheduler (ignite.contrib.handle.param_scheduler.ParamScheduler, optional): Param scheduler.
            Accepts a ParamScheduler at
            https://pytorch.org/ignite/contrib/handlers.html#module-ignite.contrib.handlers.param_scheduler
        scheduler_params (dict, optional): Parameters for scheduler
        model_checkpoint (ignite.handlers.ModelCheckpoint, optional): Model Checkpoint.
            Accepts a ModelCheckpoint at https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint
        model_checkpoint_params (dict, optional): Parameters for ModelCheckpoint at
            https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint
        early_stopping_params (dict, optional): Parameters for EarlyStopping at
            https://pytorch.org/ignite/handlers.html#ignite.handlers.EarlyStopping
        time_limit (int, optional): Time limit for training in seconds.
        train_dataset_size_limit (int, optional): If specified, only the subset of training dataset is used.
            Useful for quick preliminary check before using the whole dataset.
        val_dataset_size_limit (int, optional): If specified, only the subset of validation dataset is used.
            Useful for quick preliminary check before using the whole dataset.
        cudnn_deterministic (bool, optional): Value for torch.backends.cudnn.deterministic.
            See https://pytorch.org/docs/stable/notes/randomness.html for details.
        cudnn_benchmark (bool, optional): Value for torch.backends.cudnn.benchmark.
            See https://pytorch.org/docs/stable/notes/randomness.html for details.
        mlflow_logging (bool, optional): If True and MLflow is installed, MLflow logging is enabled.

    Returns:
        trainer (callable): a callable to train a PyTorch model.
    """

    def __init__(
        self,
        loss_fn=None,
        epochs=None,
        seed=None,
        optimizer=None,
        optimizer_params=None,
        train_data_loader_params=None,
        val_data_loader_params=None,
        evaluation_metrics=None,
        evaluate_train_data=None,
        evaluate_val_data=None,
        progress_update=None,
        scheduler=None,
        scheduler_params=None,
        model_checkpoint=None,
        model_checkpoint_params=None,
        early_stopping_params=None,
        time_limit=None,
        train_dataset_size_limit=None,
        val_dataset_size_limit=None,
        cudnn_deterministic=None,
        cudnn_benchmark=None,
        mlflow_logging=True,
        train_params=None,
    ):
        # NOTE: dict-valued parameters default to None instead of `dict()`.
        # The original `dict()` defaults were module-level mutable objects
        # shared by every instance; __call__ mutates these dicts, so state
        # leaked across instances and calls. `x or dict()` preserves the
        # old semantics for callers that passed nothing.
        self.train_params = dict(
            loss_fn=loss_fn,
            epochs=epochs,
            seed=seed,
            optimizer=optimizer,
            optimizer_params=optimizer_params or dict(),
            train_data_loader_params=train_data_loader_params or dict(),
            val_data_loader_params=val_data_loader_params or dict(),
            evaluation_metrics=evaluation_metrics,
            evaluate_train_data=evaluate_train_data,
            evaluate_val_data=evaluate_val_data,
            progress_update=progress_update,
            scheduler=scheduler,
            scheduler_params=scheduler_params or dict(),
            model_checkpoint=model_checkpoint,
            model_checkpoint_params=model_checkpoint_params or dict(),
            early_stopping_params=early_stopping_params or dict(),
            time_limit=time_limit,
            train_dataset_size_limit=train_dataset_size_limit,
            val_dataset_size_limit=val_dataset_size_limit,
            cudnn_deterministic=cudnn_deterministic,
            cudnn_benchmark=cudnn_benchmark,
        )
        # Entries in `train_params` override the individual keyword arguments.
        self.train_params.update(train_params or dict())
        self.mlflow_logging = mlflow_logging

    def __call__(self, model, train_dataset, val_dataset=None, **_):
        """Train a PyTorch model.

        Args:
            model (torch.nn.Module): PyTorch model to train.
            train_dataset (torch.utils.data.Dataset): Dataset used to train.
            val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

        Returns:
            trained_model (torch.nn.Module): Trained PyTorch model.
        """
        assert train_dataset is not None
        train_params = self.train_params
        mlflow_logging = self.mlflow_logging

        if mlflow_logging:
            try:
                import mlflow  # NOQA
            except ImportError:
                log.warning("Failed to import mlflow. MLflow logging is disabled.")
                mlflow_logging = False

        loss_fn = train_params.get("loss_fn")
        assert loss_fn
        epochs = train_params.get("epochs")
        seed = train_params.get("seed")
        optimizer = train_params.get("optimizer")
        assert optimizer
        optimizer_params = train_params.get("optimizer_params", dict())

        train_dataset_size_limit = train_params.get("train_dataset_size_limit")
        if train_dataset_size_limit:
            train_dataset = PartialDataset(train_dataset, train_dataset_size_limit)
            log.info("train dataset size is set to {}".format(len(train_dataset)))

        val_dataset_size_limit = train_params.get("val_dataset_size_limit")
        if val_dataset_size_limit and (val_dataset is not None):
            val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
            log.info("val dataset size is set to {}".format(len(val_dataset)))

        # BUGFIX: copy every dict option that is mutated below (setdefault /
        # pop / item assignment) so repeated calls do not mutate
        # self.train_params and see leftovers from previous runs.
        train_data_loader_params = dict(
            train_params.get("train_data_loader_params") or {}
        )
        val_data_loader_params = dict(train_params.get("val_data_loader_params") or {})
        evaluation_metrics = train_params.get("evaluation_metrics")
        evaluate_train_data = train_params.get("evaluate_train_data")
        evaluate_val_data = train_params.get("evaluate_val_data")
        progress_update = train_params.get("progress_update")
        scheduler = train_params.get("scheduler")
        scheduler_params = dict(train_params.get("scheduler_params") or {})
        model_checkpoint = train_params.get("model_checkpoint")
        model_checkpoint_params = train_params.get("model_checkpoint_params")
        if model_checkpoint_params:
            model_checkpoint_params = dict(model_checkpoint_params)
        early_stopping_params = train_params.get("early_stopping_params")
        if early_stopping_params:
            early_stopping_params = dict(early_stopping_params)
        time_limit = train_params.get("time_limit")
        cudnn_deterministic = train_params.get("cudnn_deterministic")
        cudnn_benchmark = train_params.get("cudnn_benchmark")

        if seed:
            torch.manual_seed(seed)
            np.random.seed(seed)
        if cudnn_deterministic:
            torch.backends.cudnn.deterministic = cudnn_deterministic
        if cudnn_benchmark:
            torch.backends.cudnn.benchmark = cudnn_benchmark

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        optimizer_ = optimizer(model.parameters(), **optimizer_params)
        trainer = create_supervised_trainer(
            model, optimizer_, loss_fn=loss_fn, device=device
        )

        train_data_loader_params.setdefault("shuffle", True)
        train_data_loader_params.setdefault("drop_last", True)
        train_data_loader_params["batch_size"] = _clip_batch_size(
            train_data_loader_params.get("batch_size", 1), train_dataset, "train"
        )
        train_loader = DataLoader(train_dataset, **train_data_loader_params)

        # Exponential moving average of the loss for smooth progress reporting.
        RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(
            trainer, "ema_loss"
        )
        # alpha = 2**-1022 (smallest positive normal double) makes the running
        # average track the raw per-batch loss.
        RunningAverage(output_transform=lambda x: x, alpha=2 ** (-1022)).attach(
            trainer, "batch_loss"
        )

        if scheduler:

            # Mix the metric-saving behavior into the user-supplied scheduler
            # class so the scheduled value shows up in engine.state.metrics.
            class ParamSchedulerSavingAsMetric(
                ParamSchedulerSavingAsMetricMixIn, scheduler
            ):
                pass

            cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
            scheduler_params.setdefault(
                "cycle_size", int(cycle_epochs * len(train_loader))
            )
            scheduler_params.setdefault("param_name", "lr")
            scheduler_ = ParamSchedulerSavingAsMetric(optimizer_, **scheduler_params)
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

        if evaluate_train_data:
            evaluator_train = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device
            )
        if evaluate_val_data:
            val_data_loader_params["batch_size"] = _clip_batch_size(
                val_data_loader_params.get("batch_size", 1), val_dataset, "val"
            )
            val_loader = DataLoader(val_dataset, **val_data_loader_params)
            evaluator_val = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device
            )

        if model_checkpoint_params:
            assert isinstance(model_checkpoint_params, dict)
            minimize = model_checkpoint_params.pop("minimize", True)
            save_interval = model_checkpoint_params.get("save_interval", None)
            if not save_interval:
                # Without a fixed save interval, checkpoint on the best
                # (optionally negated) ema_loss score.
                model_checkpoint_params.setdefault(
                    "score_function", get_score_function("ema_loss", minimize=minimize)
                )
                model_checkpoint_params.setdefault("score_name", "ema_loss")
            mc = model_checkpoint(**model_checkpoint_params)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, mc, {"model": model})

        if early_stopping_params:
            assert isinstance(early_stopping_params, dict)
            metric = early_stopping_params.pop("metric", None)
            assert (metric is None) or (metric in evaluation_metrics)
            minimize = early_stopping_params.pop("minimize", False)
            if metric:
                assert (
                    "score_function" not in early_stopping_params
                ), "Remove either 'metric' or 'score_function' from early_stopping_params: {}".format(
                    early_stopping_params
                )
                early_stopping_params["score_function"] = get_score_function(
                    metric, minimize=minimize
                )

            es = EarlyStopping(trainer=trainer, **early_stopping_params)
            if evaluate_val_data:
                evaluator_val.add_event_handler(Events.COMPLETED, es)
            elif evaluate_train_data:
                evaluator_train.add_event_handler(Events.COMPLETED, es)
            else:
                # BUGFIX: was `elif early_stopping_params:` which is always
                # true inside this guard; plain `else` is equivalent and clear.
                log.warning(
                    "Early Stopping is disabled because neither "
                    "evaluate_val_data nor evaluate_train_data is set True."
                )

        if time_limit:
            assert isinstance(time_limit, (int, float))
            tl = TimeLimit(limit_sec=time_limit)
            trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

        pbar = None
        if progress_update:
            if not isinstance(progress_update, dict):
                progress_update = dict()
            else:
                # Copy so the setdefault calls below do not mutate the dict
                # stored in self.train_params.
                progress_update = dict(progress_update)
            progress_update.setdefault("persist", True)
            progress_update.setdefault("desc", "")
            pbar = ProgressBar(**progress_update)
            pbar.attach(trainer, ["ema_loss"])
        else:

            def log_train_metrics(engine):
                log.info(
                    "[Epoch: {} | {}]".format(engine.state.epoch, engine.state.metrics)
                )

            trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_metrics)

        if evaluate_train_data:

            def log_evaluation_train_data(engine):
                evaluator_train.run(train_loader)
                train_report = _get_report_str(engine, evaluator_train, "Train Data")
                if pbar:
                    pbar.log_message(train_report)
                else:
                    log.info(train_report)

            eval_train_event = (
                Events[evaluate_train_data]
                if isinstance(evaluate_train_data, str)
                else Events.EPOCH_COMPLETED
            )
            trainer.add_event_handler(eval_train_event, log_evaluation_train_data)

        if evaluate_val_data:

            def log_evaluation_val_data(engine):
                evaluator_val.run(val_loader)
                val_report = _get_report_str(engine, evaluator_val, "Val Data")
                if pbar:
                    pbar.log_message(val_report)
                else:
                    log.info(val_report)

            eval_val_event = (
                Events[evaluate_val_data]
                if isinstance(evaluate_val_data, str)
                else Events.EPOCH_COMPLETED
            )
            trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

        if mlflow_logging:
            mlflow_logger = MLflowLogger()

            logging_params = {
                "train_n_samples": len(train_dataset),
                "train_n_batches": len(train_loader),
                "optimizer": _name(optimizer),
                "loss_fn": _name(loss_fn),
                "pytorch_version": torch.__version__,
                "ignite_version": ignite.__version__,
            }
            logging_params.update(_loggable_dict(optimizer_params, "optimizer"))
            logging_params.update(_loggable_dict(train_data_loader_params, "train"))
            if scheduler:
                logging_params.update({"scheduler": _name(scheduler)})
                logging_params.update(_loggable_dict(scheduler_params, "scheduler"))
            if evaluate_val_data:
                logging_params.update(
                    {
                        "val_n_samples": len(val_dataset),
                        "val_n_batches": len(val_loader),
                    }
                )
                logging_params.update(_loggable_dict(val_data_loader_params, "val"))
            mlflow_logger.log_params(logging_params)

            batch_metric_names = ["batch_loss", "ema_loss"]
            if scheduler:
                batch_metric_names.append(scheduler_params.get("param_name"))

            mlflow_logger.attach(
                trainer,
                log_handler=OutputHandler(
                    tag="step",
                    metric_names=batch_metric_names,
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.ITERATION_COMPLETED,
            )
            if evaluate_train_data:
                mlflow_logger.attach(
                    evaluator_train,
                    log_handler=OutputHandler(
                        tag="train",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )
            if evaluate_val_data:
                mlflow_logger.attach(
                    evaluator_val,
                    log_handler=OutputHandler(
                        tag="val",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )

        trainer.run(train_loader, max_epochs=epochs)

        # Best-effort close of the progress bar; never fail training over it.
        try:
            if pbar and pbar.pbar:
                pbar.pbar.close()
        except Exception as e:
            log.error(e, exc_info=True)

        # Reload the best/latest checkpoint (if checkpointing was enabled) so
        # the returned model is the checkpointed one, not the last state.
        model = load_latest_model(model_checkpoint_params)(model)

        return model
def get_score_function(metric, minimize=False):
    """Build an ignite-style score function reading ``engine.state.metrics``.

    Args:
        metric (str): Key looked up in ``engine.state.metrics``.
        minimize (bool, optional): If True, negate the value so handlers that
            maximize the score effectively minimize the metric.

    Returns:
        callable: ``engine -> score`` function for ignite handlers.
    """

    def _score_function(engine):
        value = engine.state.metrics.get(metric)
        if minimize:
            return -value
        return value

    return _score_function
def load_latest_model(model_checkpoint_params=None):
    """Return a function that reloads the newest ``*.pth`` checkpoint.

    Args:
        model_checkpoint_params (dict, optional): ModelCheckpoint parameters,
            or a dict containing them under the "model_checkpoint_params" key.

    Returns:
        callable: ``model -> model`` function that loads the checkpoint whose
        filename sorts last in ``dirname`` (ignite appends scores to names),
        or returns the model unchanged when nothing can be loaded.
    """
    # Unwrap one level of nesting if the params dict wraps itself.
    if model_checkpoint_params and "model_checkpoint_params" in model_checkpoint_params:
        model_checkpoint_params = model_checkpoint_params.get("model_checkpoint_params")

    def _load_latest_model(model=None):
        if not model_checkpoint_params:
            return model
        # Best-effort: any failure is logged and the input model is returned.
        try:
            dirname = model_checkpoint_params.get("dirname")
            assert dirname
            files = [str(p) for p in Path(dirname).glob("*.pth") if p.is_file()]
            if files:
                model_path = max(files)  # lexicographically last == sorted()[-1]
                log.info("Model path: {}".format(model_path))
                loaded = torch.load(model_path)
                if model_checkpoint_params.get("save_as_state_dict", True):
                    assert model
                    model.load_state_dict(loaded)
                else:
                    model = loaded
            else:
                log.warning("Model not found at: {}".format(dirname))
        except Exception as e:
            log.error(e, exc_info=True)
        return model

    return _load_latest_model
def _name(obj):
return getattr(obj, "__name__", None) or getattr(obj.__class__, "__name__", "_")
def _clip_batch_size(batch_size, dataset, tag=""):
dataset_size = len(dataset)
if batch_size > dataset_size:
log.warning(
"[{}] batch size ({}) is clipped to dataset size ({})".format(
tag, batch_size, dataset_size
)
)
return dataset_size
else:
return batch_size
def _get_report_str(engine, evaluator, tag=""):
report_str = "[Epoch: {} | {} | Metrics: {}]".format(
engine.state.epoch, tag, evaluator.state.metrics
)
return report_str
def _loggable_dict(d, prefix=None):
return {
("{}_{}".format(prefix, k) if prefix else k): (
"{}".format(v) if isinstance(v, (tuple, list, dict, set)) else v
)
for k, v in d.items()
}
class ParamSchedulerSavingAsMetricMixIn:
    """Mixin for ignite ParamScheduler subclasses that, besides applying the
    scheduled value to the optimizer, also stores it in
    ``engine.state.metrics`` so it can be reported like any other metric.

    Base code:
    https://github.com/pytorch/ignite/blob/v0.2.1/ignite/contrib/handlers/param_scheduler.py#L49
    https://github.com/pytorch/ignite/blob/v0.2.1/ignite/contrib/handlers/param_scheduler.py#L163
    """
    def __call__(self, engine, name=None):
        # NOTE(review): the attributes read here (event_index, cycle_size,
        # cycle_mult, cycle, start_value, end_value, *_mult, get_param,
        # optimizer_param_groups, param_name, save_history) are assumed to be
        # provided by the ignite ParamScheduler co-base class -- confirm
        # against the pinned ignite version (base code links above).
        if self.event_index != 0 and self.event_index % self.cycle_size == 0:
            # A cycle just completed: reset the index and scale the cycle
            # length and value range for the next cycle.
            self.event_index = 0
            self.cycle_size *= self.cycle_mult
            self.cycle += 1
            self.start_value *= self.start_value_mult
            self.end_value *= self.end_value_mult
        value = self.get_param()
        # Apply the scheduled value to every optimizer param group.
        for param_group in self.optimizer_param_groups:
            param_group[self.param_name] = value
        if name is None:
            name = self.param_name
        if self.save_history:
            # Mirror ignite's param_history bookkeeping on the engine state.
            if not hasattr(engine.state, "param_history"):
                setattr(engine.state, "param_history", {})
            engine.state.param_history.setdefault(name, [])
            values = [pg[self.param_name] for pg in self.optimizer_param_groups]
            engine.state.param_history[name].append(values)
        self.event_index += 1
        if not hasattr(engine.state, "metrics"):
            setattr(engine.state, "metrics", {})
        engine.state.metrics[self.param_name] = value  # Save as a metric
class PartialDataset:
    """View exposing only the first ``size`` items of ``dataset`` (no copy)."""

    def __init__(self, dataset, size):
        size = int(size)
        for required in ("__getitem__", "__len__"):
            assert hasattr(dataset, required)
        assert len(dataset) >= size
        self.dataset = dataset
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, item):
        return self.dataset[item]
class CopiedPartialDataset:
    """Materialized deep copy of the first ``size`` items of ``dataset``,
    so later mutation of the source does not affect this dataset."""

    def __init__(self, dataset, size):
        size = int(size)
        for required in ("__getitem__", "__len__"):
            assert hasattr(dataset, required)
        assert len(dataset) >= size
        items = []
        for index in range(size):
            items.append(copy.deepcopy(dataset[index]))
        self.dataset = items
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, item):
        return self.dataset[item]
class GetPartialDataset:
    """Callable factory producing a ``CopiedPartialDataset`` of fixed size."""

    def __init__(self, size):
        # Number of leading items to copy out of any dataset passed later.
        self.size = size

    def __call__(self, dataset):
        """Return a deep-copied view of the first ``self.size`` items."""
        return CopiedPartialDataset(dataset, self.size)