def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, filename=None, marker='x',
             grid=True, **kwargs):
    """Initialize the plot report extension.

    Args:
        y_keys: A single key (``str``) or a sequence of keys of the
            observed values to plot. A single string is wrapped into a
            one-element tuple.
        x_key (str): Key of the value used for the x axis.
        trigger: Trigger specification, converted with
            ``trigger_module.get_trigger``.
        postprocess: Stored as ``self._postprocess``.
        filename (str): Output file name. When ``None``, the legacy
            ``file_name`` keyword is used instead (default
            ``'plot.png'``).
        marker (str): Stored as ``self._marker``.
        grid (bool): Stored as ``self._grid``.
        **kwargs: Only the legacy ``file_name`` keyword is consumed here.
    """
    # `file_name` is the legacy spelling of `filename`; an explicit
    # `filename` takes precedence over it.
    legacy_name, = argument.parse_kwargs(kwargs, ('file_name', 'plot.png'))
    resolved_name = legacy_name if filename is None else filename
    del legacy_name  # avoid accidental use

    _check_available()

    keys = (y_keys, ) if isinstance(y_keys, str) else y_keys

    self._x_key = x_key
    self._y_keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = resolved_name
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    # One (initially empty) list of collected points per plotted key.
    self._data = {key: [] for key in keys}
def __init__(self, check_trigger=(1, 'epoch'), monitor='main/loss',
             patience=None, mode='auto', verbose=False,
             max_trigger=(100, 'epoch'), **kwargs):
    """Set up the early-stopping trigger state.

    Args:
        check_trigger: Trigger specification converted with
            ``util.get_trigger`` and stored as ``self._interval_trigger``.
        monitor (str): Key of the observed value to watch. It is also
            used to choose the comparison direction when ``mode`` is
            neither ``'max'`` nor ``'min'``.
        patience (int): Stored as ``self.patience``. Defaults to 3 when
            neither ``patience`` nor its alias ``patients`` is given.
        mode (str): ``'max'`` selects ``operator.gt`` and ``'min'``
            selects ``operator.lt``. Any other value (e.g. the default
            ``'auto'``) selects ``operator.gt`` if ``'accuracy'`` occurs
            in ``monitor`` and ``operator.lt`` otherwise.
        verbose (bool): When true, print the selected comparison.
        max_trigger: Trigger specification converted with
            ``util.get_trigger`` and stored as ``self._max_trigger``.
        **kwargs: Only ``patients`` (an alias of ``patience``) is
            consumed here.

    Raises:
        TypeError: If both ``patience`` and ``patients`` are given.
    """
    # Resolve `patients`, the alias of `patience`.
    patients, = argument.parse_kwargs(kwargs, ('patients', None))
    if patients is not None:
        if patience is not None:
            raise TypeError(
                'Both \'patience\' and \'patients\' arguments are '
                'specified. \'patients\' is an alias of the former. '
                'Specify only \'patience\'.')
        patience = patients
    elif patience is None:
        patience = 3

    self.count = 0
    self.patience = patience
    self.monitor = monitor
    self.verbose = verbose
    self.already_warning = False
    self._max_trigger = util.get_trigger(max_trigger)
    self._interval_trigger = util.get_trigger(check_trigger)

    self._init_summary()

    # Decide whether larger monitored values count as better.
    if mode == 'max':
        maximize = True
    elif mode == 'min':
        maximize = False
    else:
        maximize = 'accuracy' in monitor

    # The starting `best` is the worst possible value for the chosen
    # direction, so the first observation always compares as better.
    if maximize:
        self._compare = operator.gt
        if verbose:
            print('early stopping: operator is greater')
        self.best = float('-inf')
    else:
        self._compare = operator.lt
        if verbose:
            print('early stopping: operator is less')
        self.best = float('inf')
def __init__(self, iterator, target, converter=convert.concat_examples,
             device=None, eval_hook=None, eval_func=None, **kwargs):
    """Initialize the evaluator.

    Args:
        iterator: A dataset iterator, or a mapping from names to
            iterators. A single ``iterator_module.Iterator`` instance is
            stored under the name ``'main'``.
        target: A module, or a mapping from names to modules. A single
            ``module.Module`` instance is stored under the name
            ``'main'``.
        converter: Stored as ``self.converter``.
        device: When not ``None``, converted with :func:`torch.device`
            before being stored as ``self.device``.
        eval_hook: Stored as ``self.eval_hook``.
        eval_func: Stored as ``self.eval_func``.
        **kwargs: Only ``progress_bar`` (default ``False``) is consumed
            here.

    Warns:
        UserWarning: For every built-in serial/multiprocess/multithread
            iterator whose ``repeat`` attribute is truthy.
    """
    progress_bar, = argument.parse_kwargs(kwargs, ('progress_bar', False))

    if device is not None:
        device = torch.device(device)

    if isinstance(iterator, iterator_module.Iterator):
        iterator = {'main': iterator}
    self._iterators = iterator

    if isinstance(target, module.Module):
        target = {'main': target}
    self._targets = target

    self.converter = converter
    self.device = device
    self.eval_hook = eval_hook
    self.eval_func = eval_func
    self._progress_bar = progress_bar

    # Only these built-in iterator types are checked for a truthy
    # `repeat` attribute.
    checked_types = (iterators.SerialIterator,
                     iterators.MultiprocessIterator,
                     iterators.MultithreadIterator)
    for key, it in six.iteritems(iterator):
        if isinstance(it, checked_types) and getattr(it, 'repeat', False):
            # The parentheses make all the literals one expression, so
            # `.format(key)` fills the placeholder in the full message
            # instead of the continuation lines being a discarded
            # statement.
            msg = ('The `repeat` property of the iterator {} '
                   'is set to `True`. Typically, the evaluator sweeps '
                   'over iterators until they stop, '
                   'but as the property being `True`, this iterator '
                   'might not stop and evaluation could go into '
                   'an infinite loop. '
                   'We recommend to check the configuration '
                   'of iterators').format(key)
            warnings.warn(msg)
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             filename=None, **kwargs):
    """Initialize the log report extension.

    Args:
        keys: Stored as ``self._keys``.
        trigger: Trigger specification, converted with
            ``trigger_module.get_trigger``.
        postprocess: Stored as ``self._postprocess``.
        filename (str): Name of the log file. When ``None``, the legacy
            ``log_name`` keyword is used instead (default ``'log'``).
        **kwargs: Only the legacy ``log_name`` keyword is consumed here.
    """
    self._keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._postprocess = postprocess
    self._log = []  # accumulated log entries, initially empty

    # `log_name` is the legacy spelling of `filename`; an explicit
    # `filename` takes precedence over it.
    legacy_name, = argument.parse_kwargs(kwargs, ('log_name', 'log'))
    self._log_name = legacy_name if filename is None else filename

    self._init_summary()
def snapshot(savefun=None,
             filename='snapshot_iter_{.updater.iteration}',
             **kwargs):
    """snapshot(savefun=None, filename='snapshot_iter_{.updater.iteration}', \
*, target=None, condition=None, writer=None, snapshot_on_error=False, \
n_retains=-1, autoload=False)

    Returns a trainer extension to take snapshots of the trainer.

    This extension serializes the trainer object and saves it to the output
    directory. It is used to support resuming the training loop from the
    saved state.

    This extension is called once per epoch by default. To take a
    snapshot at a different interval, a trigger object specifying the
    required interval can be passed along with this extension
    to the `extend()` method of the trainer.

    The default priority is -100, which is lower than that of most
    built-in extensions.

    .. note::
       This extension first writes the serialized object to a temporary file
       and then renames it to the target file name. Thus, if the program
       stops right before the renaming, the temporary file might be left in
       the output directory.

    Args:
        savefun: Function used to save the trainer. It receives the object
            to save and the output file path. It is :func:`torch.save` by
            default. (NOTE(review): ``torch.save`` takes ``(obj, f)``;
            confirm the argument order the writer uses.) If ``writer`` is
            specified, this argument must be ``None``.
        filename (str): Name of the file into which the trainer is
            serialized. It can be a format string, where the trainer object
            is passed to the :meth:`str.format` method.
        target: Object to serialize. If it is not specified, it will
            be the trainer object.
        condition: Condition object. It must be a callable object that
            returns boolean without any arguments. If it returns ``True``,
            the snapshot will be done. If not, it will be skipped. The
            default is a function that always returns ``True``.
        writer: Writer object.
            It must be a callable object.
            See below for the list of built-in writers.
            If ``savefun`` is other than ``None``, this argument must be
            ``None``. In that case, a
            :class:`~pytorch_trainer.training.extensions.snapshot_writers.SimpleWriter`
            object instantiated with the specified ``savefun`` argument
            will be used.
        snapshot_on_error (bool): Whether to take a snapshot in case the
            training loop fails.
        n_retains (int): Number of snapshot files to retain
            through the cleanup. Must be a positive integer for any cleanup
            to take place. Automatic deletion of old snapshots only works
            when the filename is a string.
        num_retain (int): Same as ``n_retains`` (deprecated). Specifying
            both ``num_retain`` and ``n_retains`` raises :class:`TypeError`.
        autoload (bool): With this enabled, the extension automatically
            finds the latest snapshot and loads the data to the target.
            Automatic loading only works when the filename is a string.
            It is assumed that snapshots are generated by the default
            ``savefun``, :func:`torch.save`.

    Returns:
        Snapshot extension object.

    Raises:
        TypeError: If both ``savefun`` and ``writer`` are specified, or if
            both ``n_retains`` and the deprecated ``num_retain`` are
            specified.

    .. testcode::
       :hide:

       from pytorch_trainer import training
       class Model(pytorch_trainer.Link):
           def __call__(self, x):
               return x
       train_iter = pytorch_trainer.iterators.SerialIterator([], 1)
       optimizer = optimizers.SGD().setup(Model())
       updater = training.updaters.StandardUpdater(
           train_iter, optimizer, device=0)
       trainer = training.Trainer(updater)

    .. admonition:: Using asynchronous writers

        By specifying the ``writer`` argument, writing operations can be
        made asynchronous, hiding the I/O overhead of snapshots.

        >>> from pytorch_trainer.training import extensions
        >>> writer = extensions.snapshot_writers.ProcessWriter()
        >>> trainer.extend(extensions.snapshot(writer=writer), \
trigger=(1, 'epoch'))

        To change the saving function (and thus the file format), pass it
        as the ``savefun`` argument of the writer.

        >>> import torch
        >>> from pytorch_trainer.training import extensions
        >>> writer = extensions.snapshot_writers.ProcessWriter(
        ...     savefun=torch.save)
        >>> trainer.extend(extensions.snapshot(writer=writer), \
trigger=(1, 'epoch'))

    This is the list of built-in snapshot writers.

        - :class:`pytorch_trainer.training.extensions.snapshot_writers.SimpleWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.ThreadWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.ProcessWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.\
ThreadQueueWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.\
ProcessQueueWriter`

    .. seealso::

        - :meth:`pytorch_trainer.training.extensions.snapshot_object`
    """
    if 'num_retain' in kwargs:
        # Previously the deprecated alias silently overwrote an explicit
        # `n_retains`; refuse to guess which value the caller meant.
        if 'n_retains' in kwargs:
            raise TypeError(
                'Both \'n_retains\' and \'num_retain\' arguments are '
                'specified. \'num_retain\' is a deprecated alias of the '
                'former. Specify only \'n_retains\'.')
        warnings.warn(
            'Argument `num_retain` is deprecated. '
            'Please use `n_retains` instead',
            DeprecationWarning)
        kwargs['n_retains'] = kwargs.pop('num_retain')

    target, condition, writer, snapshot_on_error, n_retains, autoload = \
        argument.parse_kwargs(
            kwargs,
            ('target', None),
            ('condition', None),
            ('writer', None),
            ('snapshot_on_error', False),
            ('n_retains', -1),
            ('autoload', False))
    argument.assert_kwargs_empty(kwargs)

    if savefun is not None and writer is not None:
        raise TypeError(
            'savefun and writer arguments cannot be specified together.')

    if writer is None:
        if savefun is None:
            savefun = torch.save
        writer = snapshot_writers.SimpleWriter(savefun=savefun)

    return _Snapshot(
        target=target, condition=condition, writer=writer,
        filename=filename, snapshot_on_error=snapshot_on_error,
        n_retains=n_retains, autoload=autoload)