def __init__(
        self,
        handler: 'handler_module.BaseHandler',
        *,
        evaluator: Optional[Union['Evaluator', Tuple['Evaluator', TriggerLike]]],
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        **kwargs: Any,
):
    """Initializes the engine.

    Args:
        handler: Handler object that executes the computation steps.
        evaluator: An evaluator, or an ``(evaluator, trigger)`` pair that
            also controls when the evaluation runs. May be ``None``.
        models: A module, or a mapping of names to modules. A bare module
            is stored under the ``'main'`` key.
        kwargs: Extra options forwarded when the extensions manager is
            set up.

    Raises:
        ValueError: If ``models`` is neither a dict nor a
            ``torch.nn.Module``.
    """
    self.handler = handler
    self._manager: Optional['training.ExtensionsManager'] = None

    # The followings are used when setting up a manager instance
    if not isinstance(models, dict):
        if not isinstance(models, torch.nn.Module):
            # Fixed typo in the error message ('toch' -> 'torch').
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    else:
        self._models = models
    self._kwargs = kwargs
    self._extensions: List[  # list of (args, kwargs)
        Tuple[Tuple['training.Extension', Optional[str],
                    'TriggerLike', Optional[int]],
              Dict[str, Any]]] = []
    self._manager_state: Optional[Dict[str, Any]] = None

    if isinstance(evaluator, tuple):
        # Declared first for the annotation; immediately overwritten by
        # unpacking the (evaluator, trigger) pair.
        self.evaluator: Optional['Evaluator'] = None
        self.evaluator, trigger = evaluator
        self.evaluator_trigger = trigger_module.get_trigger(trigger)
    else:
        self.evaluator = evaluator
        # Evaluate once per epoch by default.
        self.evaluator_trigger = trigger_module.get_trigger((1, 'epoch'))
    self.val_loader = None
def __init__(self, check_trigger=(1, 'epoch'), monitor='main/loss',
             patience=None, mode='auto', verbose=False,
             max_trigger=(100, 'epoch'), **kwargs):
    """Trigger that stops training when a monitored metric stops improving.

    ``patients`` is accepted through ``kwargs`` as an alias of
    ``patience``; specifying both raises ``TypeError``.
    """
    alias = kwargs.get('patients', None)
    if alias is not None and patience is not None:
        raise TypeError(
            'Both \'patience\' and \'patients\' arguments are '
            'specified. \'patients\' is an alias of the former. '
            'Specify only \'patience\'.')
    if patience is None:
        patience = 3 if alias is None else alias

    self.count = 0
    self.patience = patience
    self.monitor = monitor
    self.verbose = verbose
    self.already_warning = False
    self._max_trigger = trigger.get_trigger(max_trigger)
    self._interval_trigger = trigger.get_trigger(check_trigger)
    self._init_summary()

    # Resolve the comparison direction; 'auto' guesses from the key name.
    if mode == 'max':
        self._compare = operator.gt
    elif mode == 'min':
        self._compare = operator.lt
    else:
        self._compare = (
            operator.gt if 'accuracy' in monitor else operator.lt)

    if self._compare is operator.gt:
        if verbose:
            print('early stopping: operator is greater')
        self.best = float('-inf')
    else:
        if verbose:
            print('early stopping: operator is less')
        self.best = float('inf')
def __init__(
        self,
        models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Dict[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[List['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: 'trigger_module.TriggerLike' = None
) -> None:
    """Initializes the extensions manager.

    Args:
        models: A module or a dict of named modules.
        optimizers: An optimizer or a dict of named optimizers.
        max_epochs: Number of epochs to train for; also used as the
            default stop trigger.
        extensions: Extensions registered at start-up, or ``None``.
        out_dir: Output directory used by the writer.
        writer: Writer used to persist results; a ``SimpleWriter`` is
            created when ``None``.
        stop_trigger: Optional trigger ending the run; defaults to
            ``(max_epochs, 'epoch')``.

    Raises:
        ValueError: If ``models`` is neither a dict nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        self._stop_trigger = trigger_module.get_trigger(
            stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation: Dict[str, reporting.ReportValue] = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = reporting.Reporter()
    self._start_extensions_called = False
    if not isinstance(models, dict):
        if not isinstance(models, torch.nn.Module):
            # Fixed typo in the error message ('toch' -> 'torch').
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    else:
        self._models = models
    if not isinstance(optimizers, dict):
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}
    else:
        self._optimizers = optimizers
    for name, model in self._models.items():
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(
            name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time: Optional[float] = None
    self._iters_per_epoch: Optional[int] = None
    self._extensions: Dict[str,
                           _ExtensionEntry] = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)
    # Initialize the writer
    self.writer.initialize(self.out)
def __init__(self, handler, *, evaluator, **kwargs):
    """Sets up the object with an optional evaluator.

    ``evaluator`` may be a single evaluator or an exact
    ``(evaluator, trigger)`` tuple; without an explicit trigger the
    evaluation runs once per epoch.
    """
    super().__init__(handler, **kwargs)
    # Note: an exact-type check (subclasses of tuple are not unpacked).
    if type(evaluator) is tuple:
        evaluator, trig = evaluator
    else:
        trig = (1, 'epoch')
    self.evaluator = evaluator
    self.evaluator_trigger = trigger_module.get_trigger(trig)
    self.val_loader = None
def __init__(
        self,
        *,
        trigger=None,
        compare_fn=_default_comparer,
        concurrency=None,
        outputs=True,
        params=False,
):
    """A class for comparison of iteration outputs and model parameters.

    This class is mainly used to compare results between different devices.

    Args:
        trigger (Trigger): Trigger object that determines when to compare
            values.
        compare_fn (function): Comparison function. Default is
            ``get_default_comparer()``.
        concurrency (int, optional): The upper bound limit on the number of
            workers that run concurrently. If ``None``, inferred from the
            size of ``engines``.
        outputs (tuple of str or bool): A set of keys of output dict to
            compare.
        params (tuple of str or bool): A set of keys of model parameters
            to compare.

    Examples:
        >>> trainer_cpu = ppe.engine.create_trainer(
            model, optimizer, 1, device='cpu')
        >>> trainer_gpu = ppe.engine.create_trainer(
            model, optimizer, 1, device='cuda:0')
        >>> comp = ppe.utils.comparer.Comparer()
        >>> comp.add_engine("cpu", engine_cpu, train_1, eval_1)
        >>> comp.add_engine("gpu", engine_gpu, train_2, eval_2)
        >>> comp.compare()
    """
    self._engines = collections.OrderedDict()
    self._engine_type = None
    self._targets = {}
    self._compare_fn = compare_fn
    self._output_keys = outputs
    self._param_keys = params
    self._finalized = False
    self._concurrency = concurrency  # Upper limit of semaphore size
    self._semaphore = None  # Semaphore for training step execution
    self._barrier = None  # Synchronizes iteration timing
    self._report_lock = threading.Lock()  # Locks `Comparer._add_target`
    if trigger is None:
        self._trigger = trigger_module.get_trigger((1, "epoch"))
    else:
        # A user-supplied trigger restricts comparison to Trainer engines.
        self._engine_type = _trainer.Trainer
        self._trigger = trigger_module.get_trigger(trigger)
def __init__(self, models, optimizers, max_epochs, extensions,
             out_dir, writer, stop_trigger=None):
    """Initializes the extensions manager.

    Args:
        models: A ``torch.nn.Module`` or a dict of named modules.
        optimizers: An optimizer or a dict of named optimizers.
        max_epochs (int): Number of epochs to train for; also used as the
            default stop trigger.
        extensions (list or None): Extensions registered at start-up.
        out_dir (str): Output directory used by the writer.
        writer: Writer used to persist results; a ``SimpleWriter`` is
            created when ``None``.
        stop_trigger: Optional trigger ending the run; defaults to
            ``(max_epochs, 'epoch')``.

    Raises:
        ValueError: If ``models`` is neither a dict nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        self._stop_trigger = trigger_module.get_trigger(stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = Reporter()
    if not isinstance(models, dict):
        if not isinstance(models, torch.nn.Module):
            # Fixed typo in the error message ('toch' -> 'torch').
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    else:
        self._models = models
    if not isinstance(optimizers, dict):
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}
    else:
        self._optimizers = optimizers
    for name, model in self._models.items():
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time = None
    self._extensions = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)
    # Initialize the writer
    self.writer.initialize(self.out)
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             filename=None, append=False, format=None, **kwargs):
    """Collects observed values and writes them out as a structured log.

    The serialization format defaults to JSON and is otherwise inferred
    from the file extension (``.jsonl`` -> json-lines, ``.yaml`` -> yaml).
    """
    self._keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._postprocess = postprocess
    self._log = []
    # When using a writer, it needs to have a savefun defined
    # to deal with a string.
    self._writer = kwargs.get('writer', None)
    # The deprecated `log_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get('log_name', 'log')
    self._log_name = filename
    if format is None and filename is not None:
        if filename.endswith('.jsonl'):
            format = 'json-lines'
        elif filename.endswith('.yaml'):
            format = 'yaml'
        else:
            format = 'json'
    self._append = append
    self._format = format
    self._init_summary()
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, filename=None, marker='x',
             grid=True, **kwargs):
    """Plots the selected report entries against ``x_key`` into a file."""
    # The deprecated `file_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get('file_name', 'plot.png')
    _check_available()
    self._x_key = x_key
    # A single key is normalized to a one-element tuple.
    self._y_keys = (y_keys, ) if isinstance(y_keys, str) else y_keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = filename
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    self._data = {key: [] for key in self._y_keys}
    self._writer = kwargs.get('writer', None)
def __init__(self, links, statistics='default', report_params=True,
             report_grads=True, prefix=None, trigger=(1, 'epoch'),
             skip_nan_params=False):
    """Reports parameter and gradient statistics of the given links."""
    # Accept a bare link as well as a sequence of links.
    if not isinstance(links, (list, tuple)):
        links = links,
    self._links = links
    if statistics is None:
        statistics = {}
    elif statistics == 'default':
        statistics = self.default_statistics
    self._statistics = dict(statistics)
    # Which tensor attributes are sampled for statistics.
    self._attrs = [
        attr for attr, enabled
        in (('data', report_params), ('grad', report_grads)) if enabled]
    self._prefix = prefix
    self._trigger = trigger_module.get_trigger(trigger)
    self._summary = reporting.DictSummary()
    self._skip_nan_params = skip_nan_params
def __init__(
        self,
        y_keys: Union[Iterable[str], str],
        x_key: str = 'iteration',
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
        postprocess: Any = None,
        filename: Optional[str] = None,
        marker: str = 'x',
        grid: bool = True,
        **kwargs: Any,
):
    """Plots the selected report entries against ``x_key`` into a file."""
    # The deprecated `file_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get('file_name', 'plot.png')
    _check_available()
    self._x_key = x_key
    # A single key is normalized to a one-element tuple.
    self._y_keys = (y_keys, ) if isinstance(y_keys, str) else y_keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = filename
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    self._data: Dict[str, List[Tuple[Any, Any]]] = {
        key: [] for key in self._y_keys}
    self._writer = kwargs.get('writer', None)
def __init__(self, scheduler, *, stepper=_default_stepper,
             trigger=(1, 'epoch')):
    """Calls ``stepper`` on ``scheduler`` every time ``trigger`` fires."""
    self.trigger = trigger_module.get_trigger(trigger)
    self.scheduler = scheduler
    self.stepper = stepper
def __init__(self, targets, max_sample_size=1000, report_data=True,
             report_grad=True, plot_mean=True, plot_std=True,
             percentile_sigmas=(
                 0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
             trigger=(1, 'epoch'), filename=None,
             figsize=None, marker=None, grid=True, **kwargs):
    """Collects variable statistics into a reservoir and plots them."""
    _check_available()
    # The deprecated `file_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get('file_name', 'statistics.png')
    self._vars = _unpack_variables(targets)
    if not self._vars:
        raise ValueError(
            'Need at least one variables for which to collect statistics.'
            '\nActual: 0 <= 0')
    if not any((plot_mean, plot_std, bool(percentile_sigmas))):
        raise ValueError('Nothing to plot')
    # Which tensor attributes are sampled.
    self._keys = [
        key for key, enabled
        in (('data', report_data), ('grad', report_grad)) if enabled]
    self._report_data = report_data
    self._report_grad = report_grad
    self._statistician = Statistician(
        collect_mean=plot_mean, collect_std=plot_std,
        percentile_sigmas=percentile_sigmas)
    self._plot_mean = plot_mean
    self._plot_std = plot_std
    self._plot_percentile = bool(percentile_sigmas)
    self._trigger = trigger_module.get_trigger(trigger)
    self._filename = filename
    self._figsize = figsize
    self._marker = marker
    self._grid = grid
    self._writer = kwargs.get('writer', None)
    if self._plot_percentile:
        if isinstance(percentile_sigmas, (list, tuple)):
            n_percentile = len(percentile_sigmas)
        else:
            n_percentile = 1  # scalar, single percentile
    else:
        n_percentile = 0
    self._data_shape = (
        len(self._keys), int(plot_mean) + int(plot_std) + n_percentile)
    self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
def __init__(
        self,
        key: str,
        compare: Callable[[float, float], bool],
        trigger: 'TriggerLike' = (1, 'epoch'),
) -> None:
    """Fires when ``key`` attains a new best value under ``compare``."""
    self._key = key
    self._compare = compare
    self._best_value: Optional[float] = None
    self._interval_trigger = trigger_module.get_trigger(trigger)
    self._init_summary()
def __init__(
        self,
        scheduler: Any,
        *,
        stepper: Any = _default_stepper,
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
        is_async: bool = True,
) -> None:
    """Calls ``stepper`` on ``scheduler`` every time ``trigger`` fires."""
    self.trigger = trigger_module.get_trigger(trigger)
    self.scheduler = scheduler
    self.stepper = stepper
    self.is_async = is_async
def __init__(self, numerator_key, denominator_key, result_key,
             trigger=(1, 'epoch')):
    """Reports an accumulated ratio ``numerator / denominator``."""
    self._numerator_key = numerator_key
    self._denominator_key = denominator_key
    self._result_key = result_key
    self._trigger = trigger_module.get_trigger(trigger)
    # Accumulators reset at every trigger interval.
    self._numerator = 0
    self._denominator = 0
def __init__(
        self,
        numerator_key: str,
        denominator_key: str,
        result_key: str,
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
) -> None:
    """Reports an accumulated ratio ``numerator / denominator``."""
    self._numerator_key = numerator_key
    self._denominator_key = denominator_key
    self._result_key = result_key
    self._trigger = trigger_module.get_trigger(trigger)
    # Accumulators reset at every trigger interval.
    self._numerator = 0.
    self._denominator = 0.
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             filename=None, **kwargs):
    """Collects observed values and writes them out as a log file."""
    self._keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._postprocess = postprocess
    self._log = []
    # When using a writer, it needs to have a savefun defined
    # to deal with a string.
    self._writer = kwargs.get('writer', None)
    # The deprecated `log_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get('log_name', 'log')
    self._log_name = filename
    self._init_summary()
def __init__(
        self,
        store_keys: Optional[Iterable[str]] = None,
        report_keys: Optional[Iterable[str]] = None,
        trigger: trigger_module.TriggerLike = (1, "epoch"),
        filename: Optional[str] = None,
        append: bool = False,
        format: Optional[str] = None,
        **kwargs: Any,
):
    """Logs timing summaries gathered by the global ``TimeSummary``."""
    self.time_summary = get_time_summary()
    # Initializes global TimeSummary.
    self.time_summary.initialize()
    if store_keys is None:
        self._store_keys = store_keys
    else:
        # Also track the matching ".std" entry for every stored key.
        self._store_keys = list(store_keys) + [
            key + ".std" for key in store_keys
        ]
    self._report_keys = report_keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._log: List[Any] = []
    # The deprecated `log_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get("log_name", "log")
    self._log_name = filename
    self._writer = kwargs.get('writer', None)
    if format is None and filename is not None:
        if filename.endswith('.jsonl'):
            format = 'json-lines'
        elif filename.endswith('.yaml'):
            format = 'yaml'
        else:
            format = 'json'
    self._append = append
    self._format = format
def __init__(
        self,
        check_trigger: 'TriggerLike' = (1, 'epoch'),
        monitor: str = 'main/loss',
        patience: int = 3,
        mode: str = 'auto',
        verbose: bool = False,
        max_trigger: Tuple[int, 'UnitLiteral'] = (100, 'epoch'),
) -> None:
    """Trigger that stops training once ``monitor`` stops improving."""
    self.count = 0
    self.patience = patience
    self.monitor = monitor
    self.verbose = verbose
    self.already_warning = False
    self._max_trigger = trigger.IntervalTrigger(*max_trigger)
    self._interval_trigger = trigger.get_trigger(check_trigger)
    self._init_summary()

    # Resolve the comparison direction; 'auto' guesses from the key name.
    if mode == 'max':
        self._compare = operator.gt
    elif mode == 'min':
        self._compare = operator.lt
    else:
        self._compare = (
            operator.gt if 'accuracy' in monitor else operator.lt)

    if self._compare is operator.gt:
        if verbose:
            print('early stopping: operator is greater')
        self.best = float('-inf')
    else:
        if verbose:
            print('early stopping: operator is less')
        self.best = float('inf')
def __init__(
        self,
        store_keys=None,
        report_keys=None,
        trigger=(1, "epoch"),
        filename=None,
        append=False,
        format=None,
        **kwargs,
):
    """Logs timing summaries gathered by the global ``TimeSummary``."""
    # Initializes global TimeSummary.
    time_summary.initialize()
    if store_keys is None:
        self._store_keys = store_keys
    else:
        # Also track the matching ".std" entry for every stored key.
        self._store_keys = list(store_keys) + [
            key + ".std" for key in store_keys
        ]
    self._report_keys = report_keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._log = []
    # The deprecated `log_name` kwarg is honored only when `filename`
    # is not given.
    if filename is None:
        filename = kwargs.get("log_name", "log")
    self._log_name = filename
    self._writer = kwargs.get('writer', None)
    if format is None and filename is not None:
        if filename.endswith('.jsonl'):
            format = 'json-lines'
        elif filename.endswith('.yaml'):
            format = 'yaml'
        else:
            format = 'json'
    self._append = append
    self._format = format
def __init__(
        self,
        keys: Optional[Iterable[str]] = None,
        trigger: trigger_module.TriggerLike = (1, 'epoch'),
        postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None,
        filename: Optional[str] = None,
        append: bool = False,
        format: Optional[str] = None,
        **kwargs: Any,
):
    """Accumulates observations and emits them through a log buffer."""
    self._keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._postprocess = postprocess
    self._log_buffer = _LogBuffer()
    self._log_looker = self._log_buffer.emit_new_looker()
    # When using a writer, it needs to have a savefun defined
    # to deal with a string.
    self._writer = kwargs.get('writer', None)
    if filename is None:
        filename = 'log'
    if format is None:
        # Infer the serialization format from the file extension.
        if filename.endswith('.jsonl'):
            format = 'json-lines'
        elif filename.endswith('.yaml'):
            format = 'yaml'
        else:
            format = 'json'
    elif format not in ('json', 'json-lines', 'yaml'):
        raise ValueError(f'unsupported log format: {format}')
    self._filename = filename
    self._append = append
    self._format = format
    self._init_summary()
def __init__(
        self,
        models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer,
                          Mapping[str, torch.optim.Optimizer]],
        max_epochs: int,
        extensions: Optional[Sequence['extension_module.ExtensionLike']],
        out_dir: str,
        writer: Optional[writing.Writer],
        stop_trigger: 'trigger_module.TriggerLike' = None,
        transform_model: _TransformModel = default_transform_model,
        enable_profile: bool = False,
) -> None:
    """Initializes the extensions manager.

    Args:
        models: A module or a mapping of named modules.
        optimizers: An optimizer or a mapping of named optimizers.
        max_epochs: Number of epochs to train for; also used as the
            default stop trigger.
        extensions: Extensions registered at start-up, or ``None``.
        out_dir: Output directory used by the writer.
        writer: Writer used to persist results; a ``SimpleWriter`` is
            created when ``None``.
        stop_trigger: Optional trigger ending the run; defaults to
            ``(max_epochs, 'epoch')``.
        transform_model: Callable applied to each model before it is
            registered as a reporter observer.
        enable_profile: Enables profiling records during training.

    Raises:
        ValueError: If ``models`` is neither a mapping nor a
            ``torch.nn.Module``.
    """
    if extensions is None:
        extensions = []
    if stop_trigger is None:
        self._stop_trigger = trigger_module.get_trigger(
            (max_epochs, 'epoch'))
    else:
        self._stop_trigger = trigger_module.get_trigger(stop_trigger)
    if writer is None:
        writer = writing.SimpleWriter(out_dir=out_dir)
    # triggers are stateful, so we need to make a copy for internal use
    self._internal_stop_trigger = copy.deepcopy(self._stop_trigger)
    self.observation: reporting.Observation = {}
    self._out = out_dir
    self.writer = writer
    self.reporter = reporting.Reporter()
    self._transform_model = transform_model
    self._start_extensions_called = False
    self._run_on_error_called = False
    # Indicates whether models can be accessed from extensions in the
    # current iteration.
    # The default value (True) indicates that it is allowed to access
    # models before starting a training loop.
    self._model_available = True
    if isinstance(models, collections.abc.Mapping):
        self._models = models
    else:
        if not isinstance(models, torch.nn.Module):
            # Fixed typo in the error message ('toch' -> 'torch').
            raise ValueError(
                'model must be an instance of dict or torch.nn.Module')
        self._models = {'main': models}
    if isinstance(optimizers, collections.abc.Mapping):
        self._optimizers = optimizers
    else:
        # TODO(ecastill) Optimizer type is not checked because of tests
        # using mocks and other classes
        self._optimizers = {'main': optimizers}
    for name, model in self._models.items():
        # TODO we should not initialize extensions at this point
        # so, we cannot use `self.models`
        model = self._transform_model(name, model)
        self.reporter.add_observer(name, model)
        self.reporter.add_observers(name, model.named_modules())
    self.max_epochs = max_epochs
    self._start_iteration = 0
    # Defer!
    self._start_time: Optional[float] = None
    self.__iters_per_epoch: Optional[int] = None
    self._extensions: Dict[
        str, extension_module.ExtensionEntry] = collections.OrderedDict()
    for ext in extensions:
        self.extend(ext)
    self._enable_profile = enable_profile
    # Initialize the writer
    self.writer.initialize(self.out)
def __init__(self, key, compare, trigger=(1, 'epoch')):
    """Fires when ``key`` attains a new best value under ``compare``."""
    self._key = key
    self._compare = compare
    self._best_value = None
    self._interval_trigger = trigger_module.get_trigger(trigger)
    self._init_summary()
def run(self,
        train_loader: Iterable[Any],
        val_loader: Optional[Iterable[Any]] = None,
        *,
        train_len: Optional[int] = None,
        eval_len: Optional[int] = None) -> None:
    """Executes the training loop.

    Args:
        train_loader (torch.utils.data.DataLoader):
            A data loader for training.
        val_loader (torch.utils.data.DataLoader, optional):
            A data loader passed to ``Evaluator.run()``.
        train_len (int, optional):
            The number of iterations per one training epoch. The
            default value is inferred from the size of training data
            loader.
        eval_len (int, optional):
            The number of iterations per one evaluation epoch, passed
            to ``Evaluator.run()``

    .. seealso::
        - :meth:`pytorch_pfn_extras.training._evaluator.Evaluator`
    """
    # Infer epoch lengths from the loaders when not given explicitly.
    if train_len is None:
        train_len = len(train_loader)  # type: ignore[arg-type]
    if eval_len is None and val_loader is not None:
        eval_len = len(val_loader)  # type: ignore[arg-type]
    self._train_len = train_len
    self._eval_len = eval_len

    # Adapter that runs an Evaluator as a manager extension so the
    # evaluation is triggered by the manager's scheduling machinery.
    class _EvaluatorExt:
        def __init__(
                self,
                trainer: 'Trainer',
                evaluator: 'Evaluator',
                val_loader: Optional[Iterable[Any]],
                eval_len: Optional[int],
        ) -> None:
            # Signals the manager that the model state must be synced
            # before this extension runs.
            self.needs_model_state = True
            self._trainer = trainer
            self._evaluator = evaluator
            self._val_loader = val_loader
            self._eval_len = eval_len

        def __call__(self, manager: ExtensionsManagerProtocol) -> None:
            evaluator = self._evaluator
            if self._val_loader is None:
                raise ValueError('"val_loader" is not given.')
            evaluator.handler.train_validation_begin(
                self._trainer, evaluator)
            evaluator.run(self._val_loader, eval_len=self._eval_len)
            evaluator.handler.train_validation_end(self._trainer, evaluator)

    # One-time setup: build the manager, register evaluators, and let the
    # handler prepare for training.
    if self._manager is None:
        self._manager = self._setup_manager(train_len)
        for name, (evaluator, trigger) in self._evaluators.items():
            # Register the evaluator as an extension to the manager
            # To be triggered with the correct timing
            self._manager.extend(
                _EvaluatorExt(self, evaluator, val_loader, eval_len),
                name=name,
                trigger=trigger_module.get_trigger(trigger),
                priority=extension.PRIORITY_WRITER,
            )
        self.handler.train_setup(self, train_loader)
        if len(self._evaluators) == 0:
            if val_loader is not None:
                warnings.warn(
                    '`val_loader` is given whereas the evaluator is '
                    'missing.', UserWarning)
        else:
            if val_loader is None:
                raise ValueError('`val_loader` is required')
            for _, (evaluator, _) in self._evaluators.items():
                evaluator.handler.eval_setup(evaluator, val_loader)

    with self._profile or _nullcontext() as prof:
        while not self.manager.stop_trigger:
            self.handler.train_epoch_begin(self, train_loader)

            # When iterations are completed in the callback
            # This is needed to avoid being constantly passing parameters
            self._idxs: 'queue.Queue[int]' = queue.Queue()
            self._inputs: 'queue.Queue[Any]' = queue.Queue()
            self._times: 'queue.Queue[float]' = queue.Queue()
            self._observed: 'queue.Queue[reporting.Observation]' = \
                queue.Queue()

            # Iterator must be created after `train_epoch_begin` as it
            # may be using a DistributedSampler.
            loader_iter = iter(train_loader)
            self._profile_records: \
                'queue.Queue[List[_ReportNotification]]' = queue.Queue()
            for idx in range(train_len):
                with record(
                    "pytorch_pfn_extras.training.Trainer:iteration",
                    use_cuda=torch.cuda.is_available(),
                    enable=self._enable_profile
                ) as ntf0:
                    try:
                        with record(
                            "pytorch_pfn_extras.training.Trainer:get_data",
                            enable=self._enable_profile
                        ):
                            x = next(loader_iter)
                    except StopIteration:
                        # The loader was exhausted mid-epoch; restart it.
                        loader_iter = iter(train_loader)
                        with record(
                            "pytorch_pfn_extras.training.Trainer:get_data",
                            enable=self._enable_profile
                        ):
                            x = next(loader_iter)
                    begin = time.time()
                    self._idxs.put(idx)
                    self._inputs.put(x)
                    self._times.put(begin)
                    with record(
                        "pytorch_pfn_extras.training.Trainer:run_iteration",
                        use_cuda=torch.cuda.is_available(),
                        enable=self._enable_profile
                    ) as ntf1, \
                            self.manager.run_iteration():
                        self._observed.put(self.manager.observation)
                        with record(
                            "pytorch_pfn_extras.training.Trainer:train_step",
                            use_cuda=torch.cuda.is_available(),
                            enable=self._enable_profile
                        ) as ntf2:
                            self._profile_records.put([ntf0, ntf1, ntf2])
                            self.handler.train_step(
                                self, idx, x,
                                complete_fn=self._complete_step)
                            # Check if the callback was called
                if prof is not None:
                    prof.step()  # type: ignore[no-untyped-call]
                # In some cases, DataLoaders are continuos
                # And will keep yielding results even if the epoch
                # is completed. We forcefully exit at the end of
                # every epoch
                if (self.is_epoch_last_iter(idx)
                        or self.manager.stop_trigger):
                    break
            # In handlers that support a completely Async model
            # train_epoch_end Will take care of completing pending work
            self.handler.train_epoch_end(self)
        if prof is not None:
            prof.on_trace_ready = None
    self.handler.train_cleanup(self)
def extend(self, extension, name=None, trigger=None, priority=None, *,
           call_before_training=False, **kwargs):
    """Registers an extension to the manager.

    An extension is a callable invoked after each update unless its
    trigger decides to skip the iteration. Execution order follows
    descending priority; extensions with equal priority run in
    registration order. If two or more extensions with the same name are
    registered, an ``_N`` ordinal suffix is appended.

    See :class:`Extension` for the interface of extensions.

    Args:
        extension: Extension to register.
        name (str): Name of the extension. If omitted, the extension's
            ``name``, ``default_name`` or ``__name__`` attribute is used
            (in that order). The name may be suffixed by an ordinal in
            case of duplicated names as explained above.
        trigger (tuple or Trigger): Trigger object that determines when
            to invoke the extension. If ``None``, ``extension.trigger``
            is used, falling back to every iteration. A non-callable
            trigger is passed to :class:`IntervalTrigger` to build an
            interval trigger.
        call_before_training (bool): Flag to call extension before
            training. Default is ``False``.
        priority (int): Invocation priority of the extension. Extensions
            are invoked in the descending order of priorities in each
            iteration. If ``None``, ``extension.priority`` is used.
    """
    if name is None:
        # Derive a name from the extension itself, first match wins.
        for attr in ('name', 'default_name', '__name__'):
            name = getattr(extension, attr, None)
            if name is not None:
                break
    if name is None:
        raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')
    if trigger is None:
        trigger = getattr(extension, 'trigger', (1, 'iteration'))
    trigger = trigger_module.get_trigger(trigger)
    if priority is None:
        priority = getattr(
            extension, 'priority', extension_module.PRIORITY_READER)
    # Deduplicate names by appending an ordinal suffix.
    final_name = name
    serial = 0
    while final_name in self._extensions:
        serial += 1
        final_name = '%s_%d' % (name, serial)
    extension.name = final_name
    self._extensions[final_name] = _ExtensionEntry(
        extension, priority, trigger, call_before_training)
def extend(
        self,
        extension: 'extension_module.ExtensionLike',
        name: Optional[str] = None,
        trigger: 'trigger_module.TriggerLike' = None,
        priority: Optional[int] = None,
        *,
        call_before_training: bool = False,
        **kwargs: Dict[str, Any],
) -> None:
    """Registers an extension to the manager.

    An extension is a callable invoked after each update unless its
    trigger decides to skip the iteration. Execution order follows
    descending priority; extensions with equal priority run in
    registration order. If two or more extensions with the same name are
    registered, an ``_N`` ordinal suffix is appended.

    See :class:`Extension` for the interface of extensions.

    Args:
        extension: Extension to register.
        name (str): Name of the extension. If omitted, the extension's
            ``name`` or ``default_name`` attribute is used. The name may
            be suffixed by an ordinal in case of duplicated names as
            explained above.
        trigger (tuple or Trigger): Trigger object that determines when
            to invoke the extension. If ``None``, the extension's own
            trigger is used. A non-callable trigger is passed to
            :class:`IntervalTrigger` to build an interval trigger.
        call_before_training (bool): Flag to call extension before
            training. Default is ``False``.
        priority (int): Invocation priority of the extension. Extensions
            are invoked in the descending order of priorities in each
            iteration. If ``None``, the extension's own priority is used.
    """
    if self._start_extensions_called:
        raise RuntimeError(
            'extend called after the extensions were initialized')
    ext = extension_module._as_extension(extension)
    if name is None:
        name = ext.name or ext.default_name
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')
    if trigger is None:
        trigger = ext.trigger
    trigger = trigger_module.get_trigger(trigger)
    if priority is None:
        priority = ext.priority
    # Deduplicate names by appending an ordinal suffix.
    final_name = name
    serial = 0
    while final_name in self._extensions:
        serial += 1
        final_name = '%s_%d' % (name, serial)
    ext.name = final_name
    self._extensions[final_name] = _ExtensionEntry(
        ext, priority, trigger, call_before_training)
def __init__(
        self,
        *,
        trigger: Optional[trigger_module.TriggerLike] = None,
        compare_fn: _CompareFn = _default_comparer,
        concurrency: Optional[int] = None,
        outputs: Union[bool, str, Sequence[str]] = True,
        params: Union[bool, str, Sequence[str]] = False,
        baseline: Optional[str] = None,
) -> None:
    """A class for comparison of iteration outputs and model parameters.

    This class is mainly used to compare results between different devices.

    Args:
        trigger (Trigger): Trigger object that determines when to compare
            values.
        compare_fn (function): Comparison function. Default is
            ``get_default_comparer()``.
        concurrency (int, optional): The upper bound limit on the number of
            workers that run concurrently. If ``None``, inferred from the
            size of ``engines``.
        outputs (tuple of str or bool): A set of keys of output dict to
            compare.
        params (tuple of str or bool): A set of keys of model parameters
            to compare.
        baseline (str, optional): The baseline engine that is assumed to
            be correct.

    Examples:
        >>> trainer_cpu = ppe.engine.create_trainer(
            model, optimizer, 1, device='cpu')
        >>> trainer_gpu = ppe.engine.create_trainer(
            model, optimizer, 1, device='cuda:0')
        >>> comp = ppe.utils.comparer.Comparer()
        >>> comp.add_engine("cpu", engine_cpu, train_1, eval_1)
        >>> comp.add_engine("gpu", engine_gpu, train_2, eval_2)
        >>> comp.compare()
    """
    self._engines: Dict[str, Tuple[Union[_Engine, _LoadDumpsEngine],
                                   Any, Any]] = collections.OrderedDict()
    self._engine_type: Optional[Type[_Engine]] = None
    self._targets: Dict[str, Dict[str, Any]] = {}
    self._compare_fn = compare_fn
    self._output_keys = outputs
    self._param_keys = params
    self._baseline = baseline
    self._finalized = False
    self._concurrency = concurrency  # Upper limit of semaphore size
    # Semaphore for training step execution.
    self._semaphore: Optional[threading.Semaphore] = None
    # Synchronizes iteration timing.
    self._barrier: Optional[threading.Barrier] = None
    self._report_lock = threading.Lock()  # Locks `Comparer._get_target`
    self._count = 0
    if trigger is None:
        self._trigger = trigger_module.get_trigger((1, "epoch"))
    else:
        # A user-supplied trigger restricts comparison to Trainer engines.
        self._engine_type = _trainer.Trainer
        self._trigger = trigger_module.get_trigger(trigger)