def __init__(self,
             y_keys,
             x_key='iteration',
             trigger=(1, 'epoch'),
             postprocess=None,
             filename=None,
             marker='x',
             grid=True,
             **kwargs):
        """Store the plot configuration and create one data series per key.

        Args:
            y_keys: Key, or iterable of keys, to keep a data series for. A
                single ``str`` is wrapped into a one-element tuple.
            x_key: Stored as ``self._x_key``.
            trigger: Converted with ``trigger_module.get_trigger`` and
                stored as ``self._trigger``.
            postprocess: Stored as ``self._postprocess``.
            filename: Output file name. When it is ``None``, the value of
                the ``file_name`` keyword argument is used instead, which
                itself defaults to ``'plot.png'``.
            marker: Stored as ``self._marker``.
            grid: Stored as ``self._grid``.
            **kwargs: Parsed by ``argument.parse_kwargs``; ``file_name`` is
                the only key requested here.

        """
        # ``file_name`` is only reachable through **kwargs; an explicit
        # ``filename`` always takes precedence over it.
        (fallback_name,) = argument.parse_kwargs(
            kwargs, ('file_name', 'plot.png'))
        target_name = fallback_name if filename is None else filename

        _check_available()

        # Normalise a lone string so iterating it yields one key, not chars.
        series_keys = (y_keys, ) if isinstance(y_keys, str) else y_keys

        self._x_key = x_key
        self._y_keys = series_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = target_name
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        # Every key gets its own, initially empty, list.
        self._data = {key: [] for key in series_keys}
Exemple #2
0
    def __init__(self,
                 updater,
                 stop_trigger=None,
                 out='result',
                 extensions=None):
        """Wire the trainer to its updater, reporter and initial extensions.

        Args:
            updater: Object providing ``get_all_models()`` and
                ``connect_trainer()``; stored as ``self.updater``.
            stop_trigger: Converted with ``trigger_module.get_trigger`` and
                stored as ``self.stop_trigger``.
            out: Stored as ``self.out``.
            extensions: Iterable of extensions, each registered through
                :meth:`extend` in iteration order. ``None`` means no
                initial extensions.

        """
        initial_extensions = [] if extensions is None else extensions

        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        # Each model is observed under its own name, and its named children
        # are registered under that same name as prefix.
        observer_hub = reporter_module.Reporter()
        for model_name, model in six.iteritems(updater.get_all_models()):
            observer_hub.add_observer(model_name, model)
            observer_hub.add_observers(model_name, model.named_children())
        self.reporter = observer_hub

        self._done = False
        self._extensions = collections.OrderedDict()

        # Elapsed-time bookkeeping.
        self._start_at = None
        self._snapshot_elapsed_time = 0.0
        self._final_elapsed_time = None

        # The updater is connected only after the trainer state above exists.
        updater.connect_trainer(self)
        for initial_extension in initial_extensions:
            self.extend(initial_extension)
    def __init__(self,
                 keys=None,
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 **kwargs):
        """Store the log configuration and start with an empty log.

        Args:
            keys: Stored as ``self._keys``.
            trigger: Converted with ``trigger_module.get_trigger`` and
                stored as ``self._trigger``.
            postprocess: Stored as ``self._postprocess``.
            filename: Name stored as ``self._log_name``. When it is
                ``None``, the value of the ``log_name`` keyword argument is
                used instead, which itself defaults to ``'log'``.
            **kwargs: Parsed by ``argument.parse_kwargs``; ``log_name`` is
                the only key requested here.

        """
        self._log = []
        self._keys = keys
        self._postprocess = postprocess
        self._trigger = trigger_module.get_trigger(trigger)

        # ``log_name`` is only reachable through **kwargs; an explicit
        # ``filename`` always takes precedence over it.
        (fallback_name,) = argument.parse_kwargs(kwargs, ('log_name', 'log'))
        self._log_name = fallback_name if filename is None else filename

        self._init_summary()
Exemple #4
0
    def extend(self,
               extension,
               name=None,
               trigger=None,
               priority=None,
               *,
               call_before_training=False,
               **kwargs):
        """Registers an extension to the trainer.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. Extensions with higher priorities are called earlier in
        each iteration; extensions sharing a priority are invoked in the
        order in which they were registered.

        When the chosen name is already registered, a suffix ``_N`` is
        appended, where ``N`` is the smallest positive integer that makes the
        name unique. The final name is also written back to
        ``extension.name``.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is ``None``, the first
                of ``extension.name``, ``extension.default_name`` and
                ``extension.__name__`` that is not ``None`` is used.
                Note that the name may be suffixed by an ordinal in case of
                duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead, falling back to ``extension_module.PRIORITY_READER``
                when the extension has no such attribute.
            call_before_training (bool): Flag to call extension before
                training. Default is ``False``.

        Raises:
            TypeError: If no name is given and none can be taken from the
                extension.
            ValueError: If the resolved name is ``"training"``.

        """
        if kwargs:
            # The removed ``invoke_before_training`` keyword gets a dedicated
            # message; the remaining check is delegated to
            # ``argument.assert_kwargs_empty``.
            argument.check_unexpected_kwargs(
                kwargs,
                invoke_before_training='invoke_before_training has been '
                'removed since Chainer v2.0.0. Use initializer= instead.')
            argument.assert_kwargs_empty(kwargs)

        if name is None:
            # Take the first attribute that yields a usable name.
            for candidate_attr in ('name', 'default_name', '__name__'):
                name = getattr(extension, candidate_attr, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        resolved_trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority',
                               extension_module.PRIORITY_READER)

        # Probe ``name``, ``name_1``, ``name_2``, ... until a free slot is
        # found.
        unique_name = name
        suffix = 0
        while unique_name in self._extensions:
            suffix += 1
            unique_name = '%s_%d' % (name, suffix)

        extension.name = unique_name
        self._extensions[unique_name] = _ExtensionEntry(
            extension, priority, resolved_trigger, call_before_training)