Example #1
    def __init__(self,
                 y_keys,
                 x_key='iteration',
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 marker='x',
                 grid=True,
                 **kwargs):

        file_name, = argument.parse_kwargs(kwargs, ('file_name', 'plot.png'))
        if filename is None:
            filename = file_name
        del file_name  # avoid accidental use

        _check_available()

        self._x_key = x_key
        if isinstance(y_keys, str):
            y_keys = (y_keys, )

        self._y_keys = y_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = filename
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {k: [] for k in y_keys}
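
The ``file_name``/``filename`` handling above pops a deprecated keyword
alias out of ``**kwargs`` via ``argument.parse_kwargs``. A minimal sketch of
the assumed semantics (not the library's actual implementation):

def parse_kwargs(kwargs, *name_and_defaults):
    # Assumed behavior: pop each named key from kwargs, falling back to
    # the supplied default, and return the values in declaration order
    # so callers can tuple-unpack them.
    return tuple(kwargs.pop(name, default)
                 for name, default in name_and_defaults)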
Example #2
    def __init__(self,
                 check_trigger=(1, 'epoch'),
                 monitor='main/loss',
                 patience=None,
                 mode='auto',
                 verbose=False,
                 max_trigger=(100, 'epoch'),
                 **kwargs):

        # Accept `patients` as a deprecated alias of `patience`.
        patients, = argument.parse_kwargs(kwargs, ('patients', None))
        if patients is None:
            if patience is None:
                patience = 3
        else:
            if patience is not None:
                raise TypeError(
                    'Both \'patience\' and \'patients\' arguments are '
                    'specified. \'patients\' is an alias of \'patience\'. '
                    'Specify only \'patience\'.')
            patience = patients

        self.count = 0
        self.patience = patience
        self.monitor = monitor
        self.verbose = verbose
        self.already_warning = False
        self._max_trigger = util.get_trigger(max_trigger)
        self._interval_trigger = util.get_trigger(check_trigger)

        self._init_summary()

        if mode == 'max':
            self._compare = operator.gt

        elif mode == 'min':
            self._compare = operator.lt

        else:
            if 'accuracy' in monitor:
                self._compare = operator.gt

            else:
                self._compare = operator.lt

        if self._compare == operator.gt:
            if verbose:
                print('early stopping: operator is greater')
            self.best = float('-inf')

        else:
            if verbose:
                print('early stopping: operator is less')
            self.best = float('inf')
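
A hypothetical usage sketch; the ``EarlyStoppingTrigger`` class name,
``Trainer``, and ``updater`` are assumptions, not confirmed by this snippet:

trigger = EarlyStoppingTrigger(monitor='validation/main/loss', mode='min',
                               patience=5, check_trigger=(1, 'epoch'),
                               max_trigger=(100, 'epoch'))
trainer = Trainer(updater, stop_trigger=trigger)  # stop when loss plateaus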
Example #3
    def __init__(self,
                 iterator,
                 target,
                 converter=convert.concat_examples,
                 device=None,
                 eval_hook=None,
                 eval_func=None,
                 **kwargs):
        progress_bar, = argument.parse_kwargs(kwargs, ('progress_bar', False))

        if device is not None:
            device = torch.device(device)

        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator

        if isinstance(target, module.Module):
            target = {'main': target}
        self._targets = target

        self.converter = converter
        self.device = device
        self.eval_hook = eval_hook
        self.eval_func = eval_func

        self._progress_bar = progress_bar

        for key, it in six.iteritems(iterator):
            if (isinstance(
                    it,
                    (iterators.SerialIterator,
                     iterators.MultiprocessIterator,
                     iterators.MultithreadIterator))
                    and getattr(it, 'repeat', False)):
                msg = ('The `repeat` property of the iterator {} '
                       'is set to `True`. Typically, the evaluator sweeps '
                       'over iterators until they stop, but because the '
                       'property is `True`, this iterator might never stop '
                       'and evaluation could go into an infinite loop. We '
                       'recommend checking the configuration of the '
                       'iterators.'.format(key))
                warnings.warn(msg)
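
Because the constructor warns when an iterator has ``repeat=True``, the
evaluation iterators are expected to be finite. A hedged usage sketch; the
``Evaluator`` class name, ``val_dataset``, ``model``, and ``trainer`` are
assumptions:

# repeat=False makes the sweep terminate after one pass over the dataset.
val_iter = iterators.SerialIterator(val_dataset, batch_size=32,
                                    repeat=False, shuffle=False)
trainer.extend(Evaluator(val_iter, model, device='cuda:0'),
               trigger=(1, 'epoch'))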
Example #4
    def __init__(self,
                 keys=None,
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 **kwargs):
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess = postprocess
        self._log = []

        log_name, = argument.parse_kwargs(
            kwargs,
            ('log_name', 'log'),
        )
        if filename is None:
            filename = log_name
        del log_name  # avoid accidental use
        self._log_name = filename

        self._init_summary()
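
Usage sketch, assuming this ``__init__`` belongs to a ``LogReport``-style
extension (the class name is inferred from the ``log_name`` default):

# Record the listed metrics once per epoch into the file `log`.
trainer.extend(LogReport(keys=['main/loss', 'validation/main/loss'],
                         trigger=(1, 'epoch'), filename='log'))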
Example #5
def snapshot(savefun=None,
             filename='snapshot_iter_{.updater.iteration}',
             **kwargs):
    """snapshot(savefun=None, filename='snapshot_iter_{.updater.iteration}', \
*, target=None, condition=None, writer=None, snapshot_on_error=False, \
n_retains=-1, autoload=False)

    Returns a trainer extension to take snapshots of the trainer.

    This extension serializes the trainer object and saves it to the output
    directory. It is used to support resuming the training loop from the saved
    state.

    This extension is called once per epoch by default. To take a
    snapshot at a different interval, a trigger object specifying the
    required interval can be passed along with this extension
    to the `extend()` method of the trainer.
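
    For example, to take a snapshot every ten epochs instead of every
    epoch::

        trainer.extend(snapshot(), trigger=(10, 'epoch'))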

    The default priority is -100, which is lower than that of most
    built-in extensions.

    .. note::
       This extension first writes the serialized object to a temporary file
       and then renames it to the target file name. Thus, if the program stops
       right before the renaming, the temporary file might be left in the
       output directory.

    Args:
        savefun: Function to save the trainer. It takes two arguments: the
            output file path and the trainer object.
            It is :func:`torch.save` by default.
            If ``writer`` is specified, this argument must be ``None``.
        filename (str): Name of the file into which the trainer is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method.
        target: Object to serialize. If it is not specified, it will
            be the trainer object.
        condition: Condition object. It must be a callable object that returns
            boolean without any arguments. If it returns ``True``, the snapshot
            will be done.
            If not, it will be skipped. The default is a function that always
            returns ``True``.
        writer: Writer object.
            It must be a callable object.
            See below for the list of built-in writers.
            If ``savefun`` is not ``None``, this argument must be ``None``;
            in that case, a
            :class:`~pytorch_trainer.training.extensions.snapshot_writers.SimpleWriter`
            object instantiated with the specified ``savefun`` is used.
        snapshot_on_error (bool): Whether to take a snapshot if the trainer
            loop fails.
        n_retains (int): Number of snapshot files to retain
            through the cleanup. Must be a positive integer for any cleanup to
            take place. Automatic deletion of old snapshots only works when
            ``filename`` is a string.
        num_retain (int): Same as ``n_retains`` (deprecated).
        autoload (bool): With this enabled, the extension
            automatically finds the latest snapshot and loads the data
            to the target. Automatic loading only works when
            ``filename`` is a string. It is assumed that snapshots are
            generated by :func:`torch.save`.

    Returns:
        Snapshot extension object.

    .. testcode::
       :hide:

       import pytorch_trainer
       from pytorch_trainer import optimizers, training
       class Model(pytorch_trainer.Link):
           def __call__(self, x):
               return x
       train_iter = pytorch_trainer.iterators.SerialIterator([], 1)
       optimizer = optimizers.SGD().setup(Model())
       updater = training.updaters.StandardUpdater(
           train_iter, optimizer, device=0)
       trainer = training.Trainer(updater)

    .. admonition:: Using asynchronous writers

        By specifying ``writer`` argument, writing operations can be made
        asynchronous, hiding I/O overhead of snapshots.

        >>> from pytorch_trainer.training import extensions
        >>> writer = extensions.snapshot_writers.ProcessWriter()
        >>> trainer.extend(extensions.snapshot(writer=writer), \
trigger=(1, 'epoch'))

        To change the format, such as npz or hdf5, you can pass a saving
        function as ``savefun`` argument of the writer.

        >>> from pytorch_trainer.training import extensions
        >>> from pytorch_trainer import serializers
        >>> writer = extensions.snapshot_writers.ProcessWriter(
        ...     savefun=serializers.save_npz)
        >>> trainer.extend(extensions.snapshot(writer=writer), \
trigger=(1, 'epoch'))

    This is the list of built-in snapshot writers.

        - :class:`pytorch_trainer.training.extensions.snapshot_writers.SimpleWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.ThreadWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.ProcessWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.\
ThreadQueueWriter`
        - :class:`pytorch_trainer.training.extensions.snapshot_writers.\
ProcessQueueWriter`

    .. seealso::

        - :meth:`pytorch_trainer.training.extensions.snapshot_object`
    """
    if 'num_retain' in kwargs:
        warnings.warn(
            'Argument `num_retain` is deprecated. '
            'Please use `n_retains` instead', DeprecationWarning)
        kwargs['n_retains'] = kwargs.pop('num_retain')

    target, condition, writer, snapshot_on_error, n_retains,\
        autoload = argument.parse_kwargs(
            kwargs,
            ('target', None), ('condition', None), ('writer', None),
            ('snapshot_on_error', False), ('n_retains', -1),
            ('autoload', False))
    argument.assert_kwargs_empty(kwargs)

    if savefun is not None and writer is not None:
        raise TypeError(
            'savefun and writer arguments cannot be specified together.')

    if writer is None:
        if savefun is None:
            savefun = torch.save
        writer = snapshot_writers.SimpleWriter(savefun=savefun)

    return _Snapshot(target=target,
                     condition=condition,
                     writer=writer,
                     filename=filename,
                     snapshot_on_error=snapshot_on_error,
                     n_retains=n_retains,
                     autoload=autoload)
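
A hedged usage sketch combining the documented keyword arguments; the
``trainer`` object is assumed to exist:

# Snapshot every epoch, keep only the five most recent files, and resume
# from the latest snapshot automatically when training restarts.
trainer.extend(
    snapshot(filename='snapshot_iter_{.updater.iteration}',
             n_retains=5, autoload=True),
    trigger=(1, 'epoch'))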