from datetime import datetime, timedelta
import os
import tempfile
import torch
import logging
log = logging.getLogger(__name__)
__all__ = ["FlexibleModelCheckpoint"]
"""
Copied from https://github.com/pytorch/ignite/blob/v0.2.1/ignite/handlers/checkpoint.py
due to the change in ignite v0.3.0
"""
class ModelCheckpoint(object):
    """ModelCheckpoint handler can be used to periodically save objects to disk.

    (Vendored from ignite v0.2.1 due to the API change in ignite v0.3.0.)

    This handler expects two arguments:

    - an :class:`~ignite.engine.Engine` object
    - a `dict` mapping names (`str`) to objects that should be saved to disk.

    Args:
        dirname (str): directory path where objects will be saved.
        filename_prefix (str): prefix for the filenames to which objects will
            be saved.
        save_interval (int, optional): if not None, objects will be saved to
            disk every `save_interval` calls to the handler. Exactly one of
            (`save_interval`, `score_function`) arguments must be provided.
        score_function (callable, optional): if not None, it should be a
            function taking a single argument, an
            :class:`~ignite.engine.Engine` object, and returning a score
            (`float`). Objects with highest scores will be retained. Exactly
            one of (`save_interval`, `score_function`) arguments must be
            provided.
        score_name (str, optional): if `score_function` is not None, its
            absolute value can be stored in the filename under this name.
        n_saved (int, optional): number of objects that should be kept on
            disk. Older/lower-scored files will be removed.
        atomic (bool, optional): if True, objects are serialized to a
            temporary file and then moved to the final destination, so that
            files are guaranteed to not be damaged (for example if an
            exception occurs during saving).
        require_empty (bool, optional): if True, will raise an exception if
            there are any files starting with `filename_prefix` in the
            directory `dirname`.
        create_dir (bool, optional): if True, will create directory `dirname`
            if it does not exist.
        save_as_state_dict (bool, optional): if True, will save only the
            `state_dict` of the objects specified, otherwise the whole object
            will be saved.

    Note:
        Each saved filename has the structure
        `{filename_prefix}_{name}_{step_number}.pth`, where `name` is a key
        of the `dict` passed to the handler and `step_number` is incremented
        by `1` with every call. If `score_function` is provided, filenames
        get the structure
        `{filename_prefix}_{name}_{step_number}_{score_name}={abs(score)}.pth`.

    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from torch import nn
        >>> trainer = Engine(lambda batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', save_interval=2, n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pth', 'myprefix_mymodel_6.pth']
    """

    def __init__(
        self,
        dirname,
        filename_prefix,
        save_interval=None,
        score_function=None,
        score_name=None,
        n_saved=1,
        atomic=True,
        require_empty=True,
        create_dir=True,
        save_as_state_dict=True,
    ):
        self._dirname = os.path.expanduser(dirname)
        self._fname_prefix = filename_prefix
        self._n_saved = n_saved
        self._save_interval = save_interval
        self._score_function = score_function
        self._score_name = score_name
        self._atomic = atomic
        # List of tuples (priority, [saved paths]); kept sorted ascending by
        # priority so the lowest-priority checkpoint is always at index 0.
        self._saved = []
        self._iteration = 0
        self._save_as_state_dict = save_as_state_dict

        if not (save_interval is None) ^ (score_function is None):
            raise ValueError(
                "Exactly one of `save_interval`, or `score_function` "
                "arguments must be provided."
            )

        if score_function is None and score_name is not None:
            raise ValueError(
                "If `score_name` is provided, then `score_function` "
                "should be also provided."
            )

        if create_dir:
            if not os.path.exists(dirname):
                os.makedirs(dirname)

        # Ensure that dirname exists
        if not os.path.exists(dirname):
            raise ValueError("Directory path '{}' is not found.".format(dirname))

        if require_empty:
            matched = [
                fname
                for fname in os.listdir(dirname)
                if fname.startswith(self._fname_prefix)
            ]
            if len(matched) > 0:
                raise ValueError(
                    "Files prefixed with {} are already present "
                    "in the directory {}. If you want to use this "
                    "directory anyway, pass `require_empty=False`."
                    "".format(filename_prefix, dirname)
                )

    def _save(self, obj, path):
        """Serialize `obj` to `path`, atomically when `self._atomic` is set."""
        if not self._atomic:
            self._internal_save(obj, path)
        else:
            # Write to a temporary file in the target directory, then rename.
            # The rename is atomic on the same filesystem, so `path` never
            # holds a partially written checkpoint.
            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self._dirname)
            try:
                self._internal_save(obj, tmp.file)
            except BaseException:
                tmp.close()
                os.remove(tmp.name)
                raise
            else:
                tmp.close()
                os.rename(tmp.name, path)

    def _internal_save(self, obj, path):
        """`torch.save` either `obj` itself or its `state_dict()` to `path`."""
        if not self._save_as_state_dict:
            torch.save(obj, path)
        else:
            if not hasattr(obj, "state_dict") or not callable(obj.state_dict):
                raise ValueError("Object should have `state_dict` method.")
            torch.save(obj.state_dict(), path)

    def __call__(self, engine, to_save):
        """Save `to_save` objects if the interval/score policy allows it.

        Args:
            engine: the :class:`~ignite.engine.Engine` invoking the handler
                (passed to `score_function`; otherwise unused).
            to_save (dict): maps names (`str`) to objects to serialize.

        Raises:
            RuntimeError: if `to_save` is empty.
        """
        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._iteration += 1

        if self._score_function is not None:
            priority = self._score_function(engine)
        else:
            priority = self._iteration
            # BUGFIX: the interval check must only run on the `save_interval`
            # path (as in upstream ignite v0.2.1). `self._save_interval` is
            # None when a `score_function` is used, and the unconditional
            # `iteration % None` raised TypeError on every score-based call.
            if (self._iteration % self._save_interval) != 0:
                return

        # Save when there is still room, or when this checkpoint beats the
        # current lowest-priority one.
        if (len(self._saved) < self._n_saved) or (self._saved[0][0] < priority):
            saved_objs = []

            suffix = ""
            if self._score_name is not None:
                suffix = "_{}={:.7}".format(self._score_name, abs(priority))

            for name, obj in to_save.items():
                fname = "{}_{}_{}{}.pth".format(
                    self._fname_prefix, name, self._iteration, suffix
                )
                path = os.path.join(self._dirname, fname)

                self._save(obj=obj, path=path)
                saved_objs.append(path)

            self._saved.append((priority, saved_objs))
            self._saved.sort(key=lambda item: item[0])

        # Evict the lowest-priority checkpoint once the retention limit is hit.
        if len(self._saved) > self._n_saved:
            _, paths = self._saved.pop(0)
            for p in paths:
                os.remove(p)
class FlexibleModelCheckpoint(ModelCheckpoint):
    """:class:`ModelCheckpoint` with customizable filename and suffix formats.

    Args:
        dirname (str): directory path where objects will be saved.
        filename_prefix (str): prefix for the filenames. If it contains a
            ``%`` character it is treated as a `strftime` format string and
            expanded to a timestamp via ``get_timestamp``.
        offset_hours (int, optional): hours added to the current time when
            expanding a timestamp-style `filename_prefix`.
        filename_format (str or callable, optional): either a format string
            with four placeholders (prefix, name, step number, suffix) or a
            callable taking those four arguments and returning a filename.
            Defaults to ``"{}_{}_{:06d}{}.pth"``.
        suffix_format (str or callable, optional): either a format string
            with two placeholders (score name, absolute score) or a callable
            taking those two arguments and returning a filename suffix.
            Defaults to ``"_{}_{:.7}"``.
        *args: forwarded to :class:`ModelCheckpoint`.
        **kwargs: forwarded to :class:`ModelCheckpoint`.
    """

    def __init__(
        self,
        dirname,
        filename_prefix,
        offset_hours=0,
        filename_format=None,
        suffix_format=None,
        *args,
        **kwargs
    ):
        # A "%" marks the prefix as a strftime pattern,
        # e.g. "%Y%m%d_%H%M%S" -> "20200101_120000".
        if "%" in filename_prefix:
            filename_prefix = get_timestamp(
                fmt=filename_prefix, offset_hours=offset_hours
            )

        super().__init__(dirname, filename_prefix, *args, **kwargs)

        # Normalize `filename_format` into a callable (str -> wrapper,
        # None -> default-format wrapper, callable -> unchanged).
        if not callable(filename_format):
            if isinstance(filename_format, str):
                format_str = filename_format
            else:
                format_str = "{}_{}_{:06d}{}.pth"

            def filename_format(filename_prefix, name, step_number, suffix):
                return format_str.format(filename_prefix, name, step_number, suffix)

        self._filename_format = filename_format

        # Normalize `suffix_format` into a callable the same way.
        if not callable(suffix_format):
            if isinstance(suffix_format, str):
                suffix_str = suffix_format
            else:
                suffix_str = "_{}_{:.7}"

            def suffix_format(score_name, abs_priority):
                return suffix_str.format(score_name, abs_priority)

        self._suffix_format = suffix_format

    def __call__(self, engine, to_save):
        """Same bookkeeping as `ModelCheckpoint.__call__`, but filenames and
        suffixes are produced by the configurable format callables."""
        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._iteration += 1

        if self._score_function is not None:
            priority = self._score_function(engine)
        else:
            priority = self._iteration
            # BUGFIX: only check the interval on the `save_interval` path;
            # `self._save_interval` is None when a `score_function` is used,
            # so an unconditional modulo would raise TypeError.
            if (self._iteration % self._save_interval) != 0:
                return

        # Save when there is still room, or when this checkpoint beats the
        # current lowest-priority one.
        if (len(self._saved) < self._n_saved) or (self._saved[0][0] < priority):
            saved_objs = []

            suffix = ""
            if self._score_name is not None:
                suffix = self._suffix_format(self._score_name, abs(priority))

            for name, obj in to_save.items():
                fname = self._filename_format(
                    self._fname_prefix, name, self._iteration, suffix
                )
                path = os.path.join(self._dirname, fname)

                self._save(obj=obj, path=path)
                saved_objs.append(path)

            self._saved.append((priority, saved_objs))
            self._saved.sort(key=lambda item: item[0])

        # Evict the lowest-priority checkpoint once the retention limit is hit.
        if len(self._saved) > self._n_saved:
            _, paths = self._saved.pop(0)
            for p in paths:
                os.remove(p)
def get_timestamp(fmt="%Y-%m-%dT%H:%M:%S", offset_hours=0):
    """Return the current local time, shifted by `offset_hours` hours,
    formatted according to the `strftime` pattern `fmt`."""
    shifted = datetime.now() + timedelta(hours=offset_hours)
    return shifted.strftime(fmt)