Example #1
    def __init__(
            self,
            models,
            optimizers,
            max_epochs,
            extensions,
            out_dir,
            writer,
            stop_trigger=None):
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = stop_trigger
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = Reporter()

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(
                name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Deferred; the start time is set when training actually begins
        self._start_time = None
        self._extensions = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)
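
The constructor above accepts either a single torch.nn.Module / optimizer or
a dict of named ones; a bare object is wrapped as {'main': obj}. Below is a
minimal sketch of both calling conventions, assuming the public
ExtensionsManager subclass and its iters_per_epoch keyword from
pytorch-pfn-extras:

import torch
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# A single object is registered under the name 'main'...
m1 = ppe.training.ExtensionsManager(model, opt, max_epochs=1,
                                    iters_per_epoch=10)

# ...which is equivalent to the explicit-dict form; the keys become the
# observer names used when reporting.
m2 = ppe.training.ExtensionsManager({'main': model}, {'main': opt},
                                    max_epochs=1, iters_per_epoch=10)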
Example #2
import collections
import copy
import warnings

import torch

from pytorch_pfn_extras import writing
from pytorch_pfn_extras.reporting import Reporter
from pytorch_pfn_extras.training import extension as extension_module
from pytorch_pfn_extras.training import trigger as trigger_module

# `_get_time` and `_ExtensionEntry` are helpers defined alongside this
# class in its source module; they are referenced but not shown here.


class _BaseExtensionsManager:
    """
    Keeps track of the extensions and the current status
    """
    def __init__(self,
                 models,
                 optimizers,
                 max_epochs,
                 extensions,
                 out_dir,
                 writer,
                 stop_trigger=None):
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = stop_trigger
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = Reporter()

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Deferred; the start time is set when training actually begins
        self._start_time = None
        self._extensions = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)

    @property
    def models(self):
        return self._models

    @property
    def optimizers(self):
        return self._optimizers

    @property
    def elapsed_time(self):
        return _get_time() - self._start_time

    @property
    def is_before_training(self):
        return self.iteration == 0

    @property
    def epoch(self):
        return self.iteration // self._iters_per_epoch

    @property
    def epoch_detail(self):
        return self.iteration / self._iters_per_epoch
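
    # For example (hypothetical numbers): with `_iters_per_epoch` == 100 and
    # `iteration` == 250, `epoch` is 250 // 100 == 2 completed epochs, while
    # `epoch_detail` is 250 / 100 == 2.5, i.e. fractional progress.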

    @property
    def stop_trigger(self):
        # The trigger is stateful: once it evaluates to True it will not
        # do so again, so the extensions must be closed the first time
        # it fires
        return self._stop_trigger(self)

    @property
    def out(self):
        if self.writer.out_dir is not None:
            return self.writer.out_dir
        else:
            return self._out

    @property
    def updater(self):
        warnings.warn(
            'The `updater` attribute has been deprecated in v0.3.0.'
            ' Use `iteration`, `epoch`, and `epoch_detail` attributes in'
            ' `ExtensionsManager` instead of attributes under `updater`.'
            ' You may also need to update the filename template specified to'
            ' snapshot extensions (e.g., from '
            '`snapshot_iter_{.updater.iteration}` to'
            ' `snapshot_iter_{.iteration}`).', DeprecationWarning)
        return self

    def _prepare_for_training(self, start_iteration, iters_per_epoch):
        self.iteration = start_iteration
        self._iters_per_epoch = iters_per_epoch

    def start_extensions(self):
        exts = self._extensions
        extension_order = sorted(exts.keys(),
                                 key=lambda name: exts[name].priority,
                                 reverse=True)
        self.extensions = [(name, exts[name]) for name in extension_order]

        # invoke initializer of each extension
        for _, entry in self.extensions:
            initializer = getattr(entry.extension, 'initialize', None)
            finished = getattr(entry.trigger, 'finished', False)
            if initializer and not finished:
                initializer(self)

        # call extensions before training loop
        self.observation = {}
        with self.reporter.scope(self.observation):
            for name, entry in self.extensions:
                if entry.call_before_training:
                    entry.extension(self)

    def extend(self,
               extension,
               name=None,
               trigger=None,
               priority=None,
               *,
               call_before_training=False,
               **kwargs):
        """Registers an extension to the manager.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each iteration.
        Extensions with the same priority are invoked in the order of
        registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second and subsequent extensions. The
        suffix is ``_N`` where ``N`` is the ordinal of the extension.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is omitted, the
                :attr:`Extension.name` attribute of the extension is used or
                the :attr:`Extension.default_name` attribute of the extension
                if `name` is set to `None` or is undefined.
                Note that the name will be suffixed by an ordinal in case of
                duplicated names, as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            call_before_training (bool): Flag to call extension before
                training. Default is ``False``.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.

        """
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    name = getattr(extension, '__name__', None)
                    if name is None:
                        raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority',
                               extension_module.PRIORITY_READER)

        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger, call_before_training)
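
    # Usage sketch (`manager` and the LogReport extension are illustrative
    # assumptions, not part of this class):
    #
    #     manager.extend(extensions.LogReport(), trigger=(1, 'epoch'))
    #     manager.extend(extensions.LogReport())  # name clash -> 'LogReport_1'
    #
    # A non-callable trigger such as (1, 'epoch') is converted into an
    # interval trigger by trigger_module.get_trigger.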

    def get_extension(self, name):
        """Returns the extension of a given name.

        Args:
            name (str): Name of the extension.

        Returns:
            Extension.

        """
        extensions = self._extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError('extension %s not found' % name)

    def run_extensions(self):
        for name, entry in self.extensions:
            if entry.trigger(self):
                entry.extension(self)

    def _finalize_extensions(self):
        for _, entry in self.extensions:
            # Some mock objects used in tests raise errors on attribute
            # lookup, so a missing `finalize` is handled by catching
            # AttributeError rather than passing a default to `getattr`.
            try:
                if getattr(entry.extension, 'finalize'):
                    entry.extension.finalize()
            except AttributeError:
                pass

    def state_dict(self, *, transform_models=lambda n, x: x):
        """
        transform_models is a function that apply a transformation
        to a model.

        When using a `torch.nn.DataParallel` model, if we want
        to save only yhe `.module` object, state_dict can be
        called as follows
        state_dict(transform_models=lambda n, x: x.module)
        """
        to_save = {}
        to_save['_start_iteration'] = self.iteration
        # TODO: should the rest of the manager status be saved too?
        to_save['models'] = {
            name: transform_models(name, self._models[name]).state_dict()
            for name in self._models
        }
        to_save['optimizers'] = {
            name: self._optimizers[name].state_dict()
            for name in self._optimizers
        }
        to_save['extensions'] = {
            name: self._extensions[name].state_dict()
            for name in self._extensions
        }
        return to_save
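
    # Checkpointing sketch for the `torch.nn.DataParallel` case from the
    # docstring above (`torch.save` target path is illustrative):
    #
    #     snapshot = manager.state_dict(
    #         transform_models=lambda n, m: m.module)  # drop the DP wrapper
    #     torch.save(snapshot, 'checkpoint.pt')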

    def load_state_dict(self, to_load, *, transform_models=lambda n, x: x):
        """
        transform_models is a function that apply a transformation
        to a model.

        When using a `torch.nn.DataParallel` model, if we want
        to load a model with the `torch.nn.DataParallel` applied
        load_state_dict(
            state, transform_models=lambda n, x: torch.nn.DataParallel(x))
        """
        self._start_iteration = to_load['_start_iteration']
        self.iteration = self._start_iteration
        for name in self._models:
            # TODO(ecastill) map_loc when loading the model and DDP check
            self._models[name].load_state_dict(to_load['models'][name])
            self._models[name] = transform_models(name, self._models[name])

        for name in self._optimizers:
            self._optimizers[name].load_state_dict(to_load['optimizers'][name])

        for name in self._extensions:
            self._extensions[name].load_state_dict(to_load['extensions'][name])
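
A typical training loop built on this manager looks roughly like the
following. This is a minimal sketch assuming the public ExtensionsManager
subclass, its iters_per_epoch keyword, the run_iteration context manager,
and pytorch_pfn_extras.reporting.report from pytorch-pfn-extras; exact
signatures may differ across versions.

import torch
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

manager = ppe.training.ExtensionsManager(
    model, optimizer, max_epochs=2, iters_per_epoch=100)
manager.extend(extensions.LogReport(), trigger=(1, 'epoch'))

for _ in range(2 * 100):  # max_epochs * iters_per_epoch
    with manager.run_iteration():  # advances iteration, opens a report scope
        x, t = torch.randn(8, 10), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ppe.reporting.report({'main/loss': float(loss)})

# Checkpoint round trip using the state_dict/load_state_dict methods above.
state = manager.state_dict()
manager.load_state_dict(state)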