class _BaseExtensionsManager:
    """
    Keeps track of the extensions and the current status.
    """

    def __init__(
            self, models, optimizers, max_epochs, extensions,
            out_dir, writer, stop_trigger=None):
        if extensions is None:
            extensions = []
        if stop_trigger is None:
            self._stop_trigger = trigger_module.get_trigger(
                (max_epochs, 'epoch'))
        else:
            self._stop_trigger = stop_trigger
        if writer is None:
            writer = writing.SimpleWriter(out_dir=out_dir)
        # triggers are stateful, so we need to make a copy for internal use
        self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
        self.observation = {}
        self._out = out_dir
        self.writer = writer
        self.reporter = Reporter()

        if not isinstance(models, dict):
            if not isinstance(models, torch.nn.Module):
                raise ValueError(
                    'model must be an instance of dict or torch.nn.Module')
            self._models = {'main': models}
        else:
            self._models = models
        if not isinstance(optimizers, dict):
            # TODO(ecastill) Optimizer type is not checked because of tests
            # using mocks and other classes
            self._optimizers = {'main': optimizers}
        else:
            self._optimizers = optimizers

        for name, model in self._models.items():
            self.reporter.add_observer(name, model)
            self.reporter.add_observers(name, model.named_modules())
        self.max_epochs = max_epochs
        self._start_iteration = 0
        # Defer!
        self._start_time = None
        self._extensions = collections.OrderedDict()
        for ext in extensions:
            self.extend(ext)

        # Initialize the writer
        self.writer.initialize(self.out)

    @property
    def models(self):
        return self._models

    @property
    def optimizers(self):
        return self._optimizers

    @property
    def elapsed_time(self):
        return _get_time() - self._start_time

    @property
    def is_before_training(self):
        return self.iteration == 0

    @property
    def epoch(self):
        return self.iteration // self._iters_per_epoch

    @property
    def epoch_detail(self):
        return self.iteration / self._iters_per_epoch

    @property
    def stop_trigger(self):
        # The trigger is stateful; the extensions are closed the first time
        # it evaluates to True, as it won't do it again
        return self._stop_trigger(self)

    @property
    def out(self):
        if self.writer.out_dir is not None:
            return self.writer.out_dir
        else:
            return self._out

    @property
    def updater(self):
        warnings.warn(
            'The `updater` attribute has been deprecated in v0.3.0.'
            ' Use `iteration`, `epoch`, and `epoch_detail` attributes in'
            ' `ExtensionsManager` instead of attributes under `updater`.'
            ' You may also need to update the filename template specified to'
            ' snapshot extensions (e.g., from '
            '`snapshot_iter_{.updater.iteration}` to'
            ' `snapshot_iter_{.iteration}`).', DeprecationWarning)
        return self

    def _prepare_for_training(self, start_iteration, iters_per_epoch):
        self.iteration = start_iteration
        self._iters_per_epoch = iters_per_epoch

    def start_extensions(self):
        exts = self._extensions
        extension_order = sorted(
            exts.keys(), key=lambda name: exts[name].priority, reverse=True)
        self.extensions = [(name, exts[name]) for name in extension_order]

        # invoke initializer of each extension
        for _, entry in self.extensions:
            initializer = getattr(entry.extension, 'initialize', None)
            finished = getattr(entry.trigger, 'finished', False)
            if initializer and not finished:
                initializer(self)

        # call extensions before training loop
        self.observation = {}
        with self.reporter.scope(self.observation):
            for name, entry in self.extensions:
                if entry.call_before_training:
                    entry.extension(self)

    def extend(self, extension, name=None, trigger=None, priority=None,
               *, call_before_training=False, **kwargs):
        """Registers an extension to the manager.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each
        iteration. Extensions with the same priority are invoked in the
        order of registration.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second and later extensions. The suffix
        is ``_N`` where N is the ordinal of the extension.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is omitted, the
                :attr:`Extension.name` attribute of the extension is used,
                or the :attr:`Extension.default_name` attribute of the
                extension if :attr:`Extension.name` is set to `None` or is
                undefined. Note that the name would be suffixed by an
                ordinal in case of duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at
                every iteration by default. If the trigger is not callable, it
                is passed to :class:`IntervalTrigger` to build an interval
                trigger.
            call_before_training (bool): Flag to call the extension before the
                training loop. Default is ``False``.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.
        """
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    name = getattr(extension, '__name__', None)
                    if name is None:
                        raise TypeError(
                            'name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(
                extension, 'priority', extension_module.PRIORITY_READER)

        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger, call_before_training)

    def get_extension(self, name):
        """Returns the extension of a given name.

        Args:
            name (str): Name of the extension.

        Returns:
            Extension.
        """
        extensions = self._extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError('extension %s not found' % name)

    def run_extensions(self):
        for name, entry in self.extensions:
            if entry.trigger(self):
                entry.extension(self)

    def _finalize_extensions(self):
        for _, entry in self.extensions:
            # Some mock objects for tests give errors
            # if we use `getattr`
            try:
                if getattr(entry.extension, 'finalize'):
                    entry.extension.finalize()
            except AttributeError:
                pass

    def state_dict(self, *, transform_models=lambda n, x: x):
        """
        transform_models is a function that applies a transformation to a
        model. When using a `torch.nn.DataParallel` model, if we want to
        save only the `.module` object, state_dict can be called as follows

        state_dict(transform_models=lambda n, x: x.module)
        """
        to_save = {}
        to_save['_start_iteration'] = self.iteration
        # Save manager status ?
        to_save['models'] = {
            name: transform_models(name, self._models[name]).state_dict()
            for name in self._models}
        to_save['optimizers'] = {
            name: self._optimizers[name].state_dict()
            for name in self._optimizers}
        to_save['extensions'] = {
            name: self._extensions[name].state_dict()
            for name in self._extensions}
        return to_save

    def load_state_dict(self, to_load, *, transform_models=lambda n, x: x):
        """
        transform_models is a function that applies a transformation to a
        model. When using a `torch.nn.DataParallel` model, if we want to
        load a model with `torch.nn.DataParallel` applied, load_state_dict
        can be called as follows

        load_state_dict(
            state, transform_models=lambda n, x: torch.nn.DataParallel(x))
        """
        self._start_iteration = to_load['_start_iteration']
        self.iteration = self._start_iteration
        for name in self._models:
            # TODO(ecastill) map_loc when loading the model and DDP check
            self._models[name].load_state_dict(to_load['models'][name])
            self._models[name] = transform_models(name, self._models[name])

        for name in self._optimizers:
            self._optimizers[name].load_state_dict(to_load['optimizers'][name])

        for name in self._extensions:
            self._extensions[name].load_state_dict(to_load['extensions'][name])
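

# Illustrative sketch (not part of the library code): registering a custom
# extension through ``extend``. ``manager`` stands for an already constructed
# manager instance; the callable, name, trigger, and priority values below are
# hypothetical and simply mirror the defaults documented in ``extend``.
#
#     def print_observation(manager):
#         # Called by ``run_extensions`` whenever the trigger fires.
#         print(manager.iteration, manager.observation)
#
#     manager.extend(
#         print_observation,
#         name='print_observation',
#         trigger=(1, 'epoch'),  # fire once per epoch
#         priority=extension_module.PRIORITY_READER,
#         call_before_training=True)  # also run once before the training loop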
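
# Illustrative sketch (not part of the library code): a save/load round trip
# for a ``torch.nn.DataParallel`` model using the ``transform_models`` hooks
# described in the ``state_dict``/``load_state_dict`` docstrings. ``manager``
# is assumed to wrap the parallelized model.
#
#     # Strip the DataParallel wrapper so only ``.module`` weights are saved.
#     snapshot = manager.state_dict(
#         transform_models=lambda name, model: model.module)
#
#     # Re-apply the wrapper when the snapshot is loaded back.
#     manager.load_state_dict(
#         snapshot,
#         transform_models=lambda name, model: torch.nn.DataParallel(model))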