def set_triggers(self):
    """Refresh the learning-rate target and triggers from ``current_state``.

    Reads ``target_lr``, ``update_trigger`` and ``stop_trigger`` from
    ``self.current_state``. Both trigger specs are converted with
    ``trigger_module.get_trigger``. The raw ``stop_trigger`` spec is also
    unpacked into ``phase_length`` and ``unit``, so it must be a
    two-element sequence.
    """
    state = self.current_state
    self.target_lr = state['target_lr']
    self.update_trigger = trigger_module.get_trigger(state['update_trigger'])
    stop_spec = state['stop_trigger']
    self.stop_trigger = trigger_module.get_trigger(stop_spec)
    self.phase_length, self.unit = stop_spec
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, filename=None, marker='x', grid=True,
             **kwargs):
    """Configure a plot of the values reported under ``y_keys``.

    Args:
        y_keys: A single key (``str``) or a sequence of keys to plot.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored for later use.
        filename: Output file name. When ``None``, the ``file_name``
            keyword from ``kwargs`` (default ``'plot.png'``) is used.
        marker: Marker style, stored as-is.
        grid: Grid flag, stored as-is.
        **kwargs: Parsed with ``argument.parse_kwargs``. Only
            ``file_name`` is looked up.
    """
    legacy_name, = argument.parse_kwargs(kwargs, ('file_name', 'plot.png'))
    resolved_name = legacy_name if filename is None else filename
    _check_available()

    keys = (y_keys,) if isinstance(y_keys, str) else y_keys

    self._x_key = x_key
    self._y_keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = resolved_name
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    # One empty series per plotted key.
    self._data = {key: [] for key in keys}
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, file_name='plot.png', marker='x',
             grid=True):
    """Configure a plot of the values reported under ``y_keys``.

    After calling ``_check_available()``, construction stops early, leaving
    no attributes configured, when the module-level ``_available`` flag is
    false.

    Args:
        y_keys: A single key (``str``) or a sequence of keys to plot.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored for later use.
        file_name: Output file name.
        marker: Marker style, stored as-is.
        grid: Grid flag, stored as-is.
    """
    _check_available()
    if not _available:
        # NOTE(review): other methods presumably re-check ``_available``
        # before touching the attributes skipped here.
        return

    keys = (y_keys,) if isinstance(y_keys, str) else y_keys
    self._x_key = x_key
    self._y_keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = file_name
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    self._data = {key: [] for key in keys}
def __init__(self, updater, stop_trigger=None, out='result',
             extensions=None):
    """Build a trainer around ``updater``.

    Args:
        updater: Object providing ``get_all_optimizers()`` and
            ``connect_trainer(trainer)``.
        stop_trigger: Trigger spec normalized with
            ``trigger_module.get_trigger``.
        out: Output location, stored as ``self.out``.
        extensions: Optional iterable of extensions. Each is registered
            through ``self.extend`` after the updater is connected.
    """
    self.updater = updater
    self.stop_trigger = trigger_module.get_trigger(stop_trigger)
    self.observation = {}
    self.out = out

    # Each optimizer's target and all of its named sub-links report under
    # the optimizer's name.
    reporter = reporter_module.Reporter()
    for opt_name, optimizer in six.iteritems(updater.get_all_optimizers()):
        model = optimizer.target
        reporter.add_observer(opt_name, model)
        reporter.add_observers(opt_name, model.namedlinks(skipself=True))
    self.reporter = reporter

    # Run-state bookkeeping.
    self._done = False
    self._extensions = collections.OrderedDict()
    self._start_at = None
    self._snapshot_elapsed_time = 0.0
    self._final_elapsed_time = None

    updater.connect_trainer(self)
    if extensions is not None:
        for ext in extensions:
            self.extend(ext)
def __init__(self, dataset_specification, dataset_class, blank_label,
             trigger=(1, 'epoch'), min_delta=0.01, attributes_to_adjust=(),
             maxlen=5, dataset_args=None):
    """Load a curriculum specification and set up the level tracker.

    Args:
        dataset_specification: Path to a JSON file. It is parsed with
            ``json.load`` and enumerated, and every entry must provide
            ``'train'`` and ``'validation'`` keys.
        dataset_class: Dataset class, stored as-is.
        blank_label: Blank label value, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        min_delta: Threshold value, stored as-is.
        attributes_to_adjust: Attribute names, stored as-is.
        maxlen: Capacity of the history deque ``self.queue``.
        dataset_args: Extra dataset arguments. ``None`` becomes ``{}``.
    """
    if dataset_args is None:
        dataset_args = {}
    self.dataset_class = dataset_class
    self.dataset_args = dataset_args
    self.trigger = trigger_module.get_trigger(trigger)
    self.maxlen = maxlen
    self.queue = deque(maxlen=self.maxlen)
    self.min_delta = min_delta
    self.attributes_to_adjust = attributes_to_adjust
    self.blank_label = blank_label
    self.force_enlarge_dataset = False

    # Parse without echoing the file object or its contents to stdout.
    # The file handle and the parsed data get distinct names.
    with open(dataset_specification) as spec_file:
        specification = json.load(spec_file)

    # Curriculum level index -> per-level train / validation settings.
    self.train_curriculum = {
        i: s['train'] for i, s in enumerate(specification)
    }
    self.validation_curriculum = {
        i: s['validation'] for i, s in enumerate(specification)
    }
    self.current_level = 0
def __init__(self, links, statistics=default_statistics,
             report_params=True, report_grads=True, prefix=None,
             trigger=(1, 'epoch'), skip_nan_params=False):
    """Configure statistics collection over link parameters and gradients.

    Args:
        links: One link, or a list/tuple of links. A lone object is
            wrapped in a 1-tuple.
        statistics: Statistics to compute. ``None`` is replaced by ``{}``.
        report_params: Include the ``'data'`` attribute when truthy.
        report_grads: Include the ``'grad'`` attribute when truthy.
        prefix: Optional prefix, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        skip_nan_params: Flag stored as ``_skip_nan_params``.
    """
    if isinstance(links, (list, tuple)):
        self._links = links
    else:
        self._links = (links,)
    self._statistics = {} if statistics is None else statistics

    candidates = (('data', report_params), ('grad', report_grads))
    self._attrs = [attr for attr, enabled in candidates if enabled]

    self._prefix = prefix
    self._trigger = trigger_module.get_trigger(trigger)
    self._summary = reporter.DictSummary()
    self._skip_nan_params = skip_nan_params
def __init__(self, targets, max_sample_size=1000, report_data=True,
             report_grad=True, plot_mean=True, plot_std=True,
             percentile_sigmas=(0, 0.13, 2.28, 15.87, 50, 84.13, 97.72,
                                99.87, 100),
             trigger=(1, 'epoch'), file_name='statistics.png', figsize=None,
             marker=None, grid=True):
    """Configure a plot of statistics sampled from ``targets``.

    Args:
        targets: Passed to ``_unpack_variables``. At least one variable
            must result.
        max_sample_size: Capacity passed to ``Reservoir``.
        report_data: Track the ``'data'`` key when truthy.
        report_grad: Track the ``'grad'`` key when truthy.
        plot_mean: Collect and plot the mean.
        plot_std: Collect and plot the standard deviation.
        percentile_sigmas: A list/tuple of percentiles, or a scalar that
            counts as one percentile. A falsy value disables percentiles.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        file_name: Output file name. Must not be ``None``.
        figsize: Stored as-is.
        marker: Stored as-is.
        grid: Stored as-is.

    Raises:
        ValueError: If ``file_name`` is ``None``, no variables are found,
            or mean, std and percentiles are all disabled.
    """
    if file_name is None:
        raise ValueError('Missing output file name of statstics plot')

    self._vars = _unpack_variables(targets)
    if len(self._vars) == 0:
        raise ValueError(
            'Need at least one variables for which to collect statistics.'
            '\nActual: 0 <= 0')

    plot_percentile = bool(percentile_sigmas)
    if not (plot_mean or plot_std or plot_percentile):
        raise ValueError('Nothing to plot')

    wanted = (('data', report_data), ('grad', report_grad))
    self._keys = [key for key, enabled in wanted if enabled]
    self._report_data = report_data
    self._report_grad = report_grad
    self._statistician = Statistician(
        collect_mean=plot_mean, collect_std=plot_std,
        percentile_sigmas=percentile_sigmas)
    self._plot_mean = plot_mean
    self._plot_std = plot_std
    self._plot_percentile = plot_percentile
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = file_name
    self._figsize = figsize
    self._marker = marker
    self._grid = grid

    # Columns per key: optional mean, optional std, then the percentiles.
    # NOTE(review): a scalar ``percentile_sigmas`` of 0 is falsy and
    # therefore disables percentiles entirely.
    if not plot_percentile:
        n_percentile = 0
    elif isinstance(percentile_sigmas, (list, tuple)):
        n_percentile = len(percentile_sigmas)
    else:
        n_percentile = 1  # a scalar stands for a single percentile
    n_columns = int(plot_mean) + int(plot_std) + n_percentile
    self._data_shape = (len(self._keys), n_columns)
    self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
def __init__(self, trigger=(1, 'iteration'), receivers=None,
             file_name='commands'):
    """Configure the command-handling extension.

    Args:
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        receivers: Optional mapping merged over a copy of the class-level
            ``default_receivers``. ``None`` (the default) adds nothing.
        file_name: File name, stored as ``_file_name``.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = file_name
    # Start from a private copy so updates never touch ``default_receivers``.
    self._receivers = self.default_receivers.copy()
    # ``None`` replaces the former mutable ``{}`` default argument.
    if receivers is not None:
        self._receivers.update(receivers)
def __init__(self, updater, stop_trigger=None, out='result',
             extensions=None):
    """Create a trainer that drives ``updater``.

    Args:
        updater: Supplies ``get_all_optimizers()`` and
            ``connect_trainer(trainer)``.
        stop_trigger: Trigger spec, normalized with
            ``trigger_module.get_trigger``.
        out: Output location, stored as ``self.out``.
        extensions: Optional iterable. Each item is passed to
            ``self.extend`` once the updater is connected.
    """
    self.updater = updater
    self.stop_trigger = trigger_module.get_trigger(stop_trigger)
    self.observation = {}
    self.out = out

    self.reporter = reporter_module.Reporter()
    optimizers = updater.get_all_optimizers()
    for key, opt in six.iteritems(optimizers):
        self.reporter.add_observer(key, opt.target)
        self.reporter.add_observers(
            key, opt.target.namedlinks(skipself=True))

    self._done = False
    self._extensions = collections.OrderedDict()
    self._start_at = None
    self._snapshot_elapsed_time = 0.0
    self._final_elapsed_time = None

    updater.connect_trainer(self)
    for ext in (() if extensions is None else extensions):
        self.extend(ext)
def __init__(self, shift, attr='lr', trigger=(1, 'iteration'),
             min_delta=0.1, maxlen=5):
    """Configure a shift of the optimizer attribute ``attr``.

    Args:
        shift: Shift value, stored as-is.
        attr: Name of the attribute to adjust.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        min_delta: Threshold value, stored as-is.
        maxlen: Capacity of the history deque ``self.queue``. The default
            of 5 matches the previously hard-coded size.
    """
    self.shift = shift
    self.attr = attr
    self.trigger = trigger_module.get_trigger(trigger)
    self.maxlen = maxlen
    self.queue = deque(maxlen=maxlen)
    self.min_delta = min_delta
    self.force_shift = False
def __init__(self, targets, max_sample_size=1000, report_data=True,
             report_grad=True, plot_mean=True, plot_std=True,
             percentile_sigmas=(
                 0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
             trigger=(1, 'epoch'), filename=None, figsize=None,
             marker=None, grid=True, **kwargs):
    """Configure a plot of statistics sampled from ``targets``.

    Args:
        targets: Passed to ``_unpack_variables``. At least one variable
            must result.
        max_sample_size: Capacity passed to ``Reservoir``.
        report_data: Track the ``'data'`` key when truthy.
        report_grad: Track the ``'grad'`` key when truthy.
        plot_mean: Collect and plot the mean.
        plot_std: Collect and plot the standard deviation.
        percentile_sigmas: A list/tuple of percentiles, or a scalar that
            counts as one percentile. A falsy value disables percentiles.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        filename: Output file name. When ``None`` (the default), the
            ``file_name`` keyword from ``kwargs`` is used, which itself
            defaults to ``'statistics.png'``.
        figsize: Stored as-is.
        marker: Stored as-is.
        grid: Stored as-is.
        **kwargs: Parsed with ``argument.parse_kwargs``. Only
            ``file_name`` is looked up.

    Raises:
        ValueError: If no output name results, no variables are found, or
            mean, std and percentiles are all disabled.
    """
    file_name, = argument.parse_kwargs(
        kwargs, ('file_name', 'statistics.png')
    )
    # ``filename`` must default to None. With a non-None default this
    # fallback never fires and an explicit ``file_name=`` is ignored.
    if filename is None:
        filename = file_name
    if filename is None:
        raise ValueError('Missing output file name of statstics plot')

    self._vars = _unpack_variables(targets)
    if len(self._vars) == 0:
        raise ValueError(
            'Need at least one variables for which to collect statistics.'
            '\nActual: 0 <= 0')
    if not any((plot_mean, plot_std, bool(percentile_sigmas))):
        raise ValueError('Nothing to plot')

    self._keys = []
    if report_data:
        self._keys.append('data')
    if report_grad:
        self._keys.append('grad')
    self._report_data = report_data
    self._report_grad = report_grad
    self._statistician = Statistician(
        collect_mean=plot_mean, collect_std=plot_std,
        percentile_sigmas=percentile_sigmas)
    self._plot_mean = plot_mean
    self._plot_std = plot_std
    self._plot_percentile = bool(percentile_sigmas)
    self._trigger = trigger_module.get_trigger(trigger)
    self._filename = filename
    self._figsize = figsize
    self._marker = marker
    self._grid = grid

    # Columns per key: optional mean, optional std, then the percentiles.
    if not self._plot_percentile:
        n_percentile = 0
    elif isinstance(percentile_sigmas, (list, tuple)):
        n_percentile = len(percentile_sigmas)
    else:
        n_percentile = 1  # scalar, single percentile
    self._data_shape = (
        len(self._keys), int(plot_mean) + int(plot_std) + n_percentile)
    self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             log_name='log'):
    """Configure log collection.

    Args:
        keys: Keys to record, stored as-is (``None`` allowed).
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        log_name: Log file name, stored as ``_log_name``.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._keys = keys
    self._log_name = log_name
    self._postprocess = postprocess
    self._log = []  # accumulated log entries
    self._init_summary()
def __init__(self, detector, image_paths, names, size, thresh=0.6,
             trigger=(1, 'epoch'), device=-1):
    """Store a detector together with the image set and settings it uses.

    Args:
        detector: Detector object, stored as ``_detector``.
        image_paths: Image paths, stored as-is.
        names: Names, stored as-is. Presumably class labels; confirm.
        size: Size setting, stored as-is.
        thresh: Threshold value, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        device: Device identifier, stored as-is.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._detector, self._device = detector, device
    self._image_paths = image_paths
    self._names = names
    self._size = size
    self._thresh = thresh
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             log_name='log'):
    """Set up log collection state.

    Args:
        keys: Keys to record (``None`` allowed), stored as-is.
        trigger: Trigger spec, normalized with
            ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        log_name: Log file name.
    """
    self._keys, self._postprocess = keys, postprocess
    self._trigger = trigger_module.get_trigger(trigger)
    self._log_name = log_name
    self._log = []
    self._init_summary()
def __init__(self, notifier, custom_metrics=None):
    """Attach a notifier and prepare per-run state.

    Args:
        notifier: Required notifier object.
        custom_metrics: Optional metrics. They are stored as
            ``self.custom_metrics`` only when not ``None``.
            NOTE(review): otherwise the attribute is left unset here.

    Raises:
        ValueError: If ``notifier`` is ``None``.
    """
    # TODO: compute additional metrics inside the training loop.
    if custom_metrics is not None:
        self.custom_metrics = custom_metrics
    if notifier is None:
        raise ValueError('Notifier is None')
    self.notifier = notifier

    # Notifier-specific state.
    self.details = None  # dict of available details: epochs, batch_size, lr, etc.
    self.starting_time = None
    self.current_epoch = 1

    # Chainer-specific triggers: one per epoch, one per iteration.
    self._trigger_epoch = trigger_module.get_trigger((1, 'epoch'))
    self._trigger_iteration = trigger_module.get_trigger((1, 'iteration'))
def __init__(self, dataset, crop_sizes, max_size_iteration=None,
             trigger=(80, 'iteration')):
    """Configure crop-size updates for ``dataset``.

    Args:
        dataset: Dataset object, stored as-is.
        crop_sizes: Candidate crop sizes, stored as-is.
        max_size_iteration: Optional iteration limit, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
    """
    super(CropSizeUpdater, self).__init__()
    self._trigger = trigger_module.get_trigger(trigger)
    self.dataset, self.crop_sizes = dataset, crop_sizes
    self.max_size_iteration = max_size_iteration
def __init__(self, keys=None, trigger=(1, 'iteration'),
             log_json_name='log', log_csv_name='log.csv'):
    """Configure a report that keeps both JSON and CSV log names.

    Args:
        keys: Keys to record (``None`` allowed).
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        log_json_name: JSON log file name.
        log_csv_name: CSV log file name.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._keys = keys
    self._log_json_name = log_json_name
    self._log_csv_name = log_csv_name
    self._log = []
    self._init_summary()
    # Stays False until another method marks the report as initialized.
    self._is_initialized = False
def __init__(self, watch_items, trigger=(100, 'iteration'), interval=5.0):
    """Prepare a heartbeat watcher over ``watch_items``.

    Entries that are not already ``WatchItem`` instances must unpack into
    ``(action, estimator)`` and are wrapped in ``WatchItem``. A daemon
    ``Process`` targeting ``self._heartbeat_handler`` is created but not
    started here.

    Args:
        watch_items: Iterable of ``WatchItem`` or ``(action, estimator)``.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        interval: Interval value, stored as-is.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._interval = interval

    normalized = []
    for entry in watch_items:
        if not isinstance(entry, WatchItem):
            action, estimator = entry
            entry = WatchItem(action=action, estimator=estimator)
        normalized.append(entry)
    self._watch_items = normalized

    self._heartbeat_queue = Queue()
    self._stop_event = Event()
    worker = Process(target=self._heartbeat_handler)
    worker.daemon = True
    self._heartbeat_thread = worker
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, file_name='graph.png'):
    """Configure a graph of the values reported under ``y_keys``.

    Args:
        y_keys: A single key (``str``) or a sequence of keys.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        file_name: Output file name.
    """
    _check_available()
    keys = (y_keys,) if isinstance(y_keys, str) else y_keys
    self._x_key = x_key
    self._y_keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._file_name = file_name
    self._postprocess = postprocess
    self._init_summary()
    self._data = {key: [] for key in keys}
def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
             filename=None, **kwargs):
    """Configure log collection.

    Args:
        keys: Keys to record (``None`` allowed).
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        filename: Log file name. When ``None``, the ``log_name`` keyword
            from ``kwargs`` (default ``'log'``) is used.
        **kwargs: Parsed with ``argument.parse_kwargs``. Only
            ``log_name`` is looked up.
    """
    self._keys = keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._postprocess = postprocess
    self._log = []
    legacy_name, = argument.parse_kwargs(kwargs, ('log_name', 'log'))
    self._log_name = legacy_name if filename is None else filename
    self._init_summary()
def __init__(self, target, keys, s_keys=None, mname=None,
             log_report='LogReport', trigger=(1, 'epoch')):
    """Configure an extension over one or more target links.

    Args:
        target: A ``link.Link`` (wrapped as ``{'main': target}``) or an
            already-built mapping of targets.
        keys: Keys, stored as-is.
        s_keys: Optional secondary keys. ``None`` (the default) yields a
            fresh empty list for this instance.
        mname: Optional name, stored as-is.
        log_report: Name of the log report, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
    """
    if isinstance(target, link.Link):
        target = {'main': target}
    self._targets = target
    self._keys = keys
    # A shared ``[]`` default would be aliased by every instance built
    # without ``s_keys``, so each instance gets its own list.
    self._s_keys = [] if s_keys is None else s_keys
    self._mname = mname
    self._log_report = log_report
    self.trigger = trigger_module.get_trigger(trigger)
def __init__(
    self,
    writer,
    keys=None,
    trigger=(1, "epoch"),
    postprocess=None,
    log_name="log.json",
):
    """Configure a log report that also keeps a ``writer`` object.

    Args:
        writer: Writer object, stored as ``_writer``.
        keys: Keys to record (``None`` allowed).
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        log_name: Log file name.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self._writer, self._keys = writer, keys
    self._postprocess = postprocess
    self._log_name = log_name
    self._log = []
    self._init_summary()
def __init__(self, updater, stop_trigger=None, out='result'):
    """Build a trainer around ``updater``.

    Args:
        updater: Supplies ``get_all_optimizers()`` and
            ``connect_trainer(trainer)``.
        stop_trigger: Trigger spec normalized with
            ``trigger_module.get_trigger``.
        out: Output location, stored as ``self.out``.
    """
    self.updater = updater
    self.stop_trigger = trigger_module.get_trigger(stop_trigger)
    self.observation = {}
    self.out = out

    # Register each optimizer's target and its named sub-links under the
    # optimizer's name.
    rep = reporter_module.Reporter()
    for opt_name, opt in six.iteritems(updater.get_all_optimizers()):
        model = opt.target
        rep.add_observer(opt_name, model)
        rep.add_observers(opt_name, model.namedlinks(skipself=True))
    self.reporter = rep

    self._done = False
    self._extensions = collections.OrderedDict()
    updater.connect_trainer(self)
def __init__(self, updater, stop_trigger=None, out='result'):
    """Create a trainer that drives ``updater``.

    Args:
        updater: Supplies ``get_all_optimizers()`` and
            ``connect_trainer(trainer)``.
        stop_trigger: Trigger spec, normalized with
            ``trigger_module.get_trigger``.
        out: Output location.
    """
    self.updater = updater
    self.stop_trigger = trigger_module.get_trigger(stop_trigger)
    self.observation = {}
    self.out = out

    self.reporter = reporter_module.Reporter()
    for key, optimizer in six.iteritems(updater.get_all_optimizers()):
        self.reporter.add_observer(key, optimizer.target)
        self.reporter.add_observers(
            key, optimizer.target.namedlinks(skipself=True))

    self._done = False
    self._extensions = collections.OrderedDict()
    updater.connect_trainer(self)
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             log_dir=None):
    """Configure a ``SummaryWriter``-backed plot of ``y_keys``.

    Args:
        y_keys: A single key (``str``) or a sequence of keys.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        log_dir: Directory handed to ``SummaryWriter``. It is created if
            missing. When ``None`` (the default), no directory is created
            here and ``None`` is passed through, so the writer's own
            default location applies. NOTE(review): confirm this for the
            SummaryWriter implementation in use.
    """
    # os.path.isdir(None) / os.makedirs(None) raise, which made the
    # documented default unusable; only touch the filesystem for a path.
    if log_dir is not None and not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    self._log_dir = log_dir
    self._writer = SummaryWriter(log_dir)
    self._x_key = x_key
    if isinstance(y_keys, str):
        y_keys = (y_keys, )
    self._y_keys = y_keys
    self._trigger = trigger_module.get_trigger(trigger)
    self._init_summary()
    self._data = {k: [] for k in y_keys}
def __init__(self, links, statistics=default_statistics,
             report_params=True, report_grads=True, prefix=None,
             trigger=(1, 'epoch')):
    """Configure statistics collection over link parameters and gradients.

    Args:
        links: One link, or a list/tuple of links. A lone object is
            wrapped in a 1-tuple.
        statistics: Statistics to compute. ``None`` is normalized to
            ``{}``, matching the variant that supports
            ``skip_nan_params``.
        report_params: Include the ``'data'`` attribute when truthy.
        report_grads: Include the ``'grad'`` attribute when truthy.
        prefix: Optional prefix, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
    """
    if not isinstance(links, (list, tuple)):
        links = links,
    self._links = links
    # Accept ``None`` as "no statistics" instead of storing it raw.
    if statistics is None:
        statistics = {}
    self._statistics = statistics
    attrs = []
    if report_params:
        attrs.append('data')
    if report_grads:
        attrs.append('grad')
    self._attrs = attrs
    self._prefix = prefix
    self._trigger = trigger_module.get_trigger(trigger)
    self._summary = reporter.DictSummary()
def __init__(self, trigger=(10000, 'iteration'), postprocess=None,
             segmentation_loss_key='main/loss/mask',
             detection_loss_loc_key='main/loss/loc',
             detection_loss_conf_key='main/loss/conf',
             smooth_alpha=0.85, split_alpha=0.15):
    """Configure the division of segmentation and detection losses.

    Args:
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        segmentation_loss_key: Reported key of the segmentation loss.
        detection_loss_loc_key: Reported key of the localization loss.
        detection_loss_conf_key: Reported key of the confidence loss.
        smooth_alpha: Stored as ``self.alpha``.
        split_alpha: Stored as ``self.split_alpha``.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self.alpha = smooth_alpha
    self.split_alpha = split_alpha
    self._postprocess = postprocess

    self._segmentation_loss_key = segmentation_loss_key
    self._detection_loss_loc_key = detection_loss_loc_key
    self._detection_loss_conf_key = detection_loss_conf_key

    # Running loss trackers; all start unset.
    self._max_loss_seg = self._current_loss_seg = None
    self._max_loss_det_loc = self._current_loss_det_loc = None
    self._max_loss_det_conf = self._current_loss_det_conf = None
    self._current_loss_split = None
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, file_name='plot.png', marker='x',
             grid=True):
    """Configure a plot of the values reported under ``y_keys``.

    After ``_check_available()``, construction stops early when the
    module-level ``_available`` flag is false. No attributes are set in
    that case.

    Args:
        y_keys: A single key (``str``) or a sequence of keys.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        file_name: Output file name.
        marker: Marker style, stored as-is.
        grid: Grid flag, stored as-is.
    """
    _check_available()
    if _available:
        if isinstance(y_keys, str):
            keys = (y_keys,)
        else:
            keys = y_keys
        self._x_key = x_key
        self._y_keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = dict((key, []) for key in keys)
def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
             postprocess=None, filename=None, marker='x', grid=True,
             **kwargs):
    """Configure a plot of the values reported under ``y_keys``.

    Args:
        y_keys: A single key (``str``) or a sequence of keys.
        x_key: Key used for the horizontal axis.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        postprocess: Optional callback, stored as-is.
        filename: Output file name. When ``None``, the ``file_name``
            keyword from ``kwargs`` (default ``'plot.png'``) is used.
        marker: Marker style, stored as-is.
        grid: Grid flag, stored as-is.
        **kwargs: Parsed with ``argument.parse_kwargs``. Only
            ``file_name`` is looked up.
    """
    file_name, = argument.parse_kwargs(kwargs, ('file_name', 'plot.png'))
    if filename is None:
        filename = file_name
    del file_name  # from here on only the resolved ``filename`` is valid
    _check_available()
    self._x_key = x_key
    if isinstance(y_keys, str):
        y_keys = (y_keys,)
    self._y_keys = y_keys
    self._trigger = trigger_module.get_trigger(trigger)
    # Store the resolved name. Storing ``file_name`` here ignored any
    # explicit ``filename=`` argument.
    self._file_name = filename
    self._marker = marker
    self._grid = grid
    self._postprocess = postprocess
    self._init_summary()
    self._data = {k: [] for k in y_keys}
def extend(self, extension, name=None, trigger=None, priority=None,
           **kwargs):
    """Register ``extension`` under a unique name.

    Args:
        extension: Extension object. Its ``name`` attribute is overwritten
            with the final, possibly suffixed, registration name.
        name (str): Registration name. When ``None``, the first
            non-``None`` value among ``extension.name``,
            ``extension.default_name`` and ``extension.__name__`` is used.
        trigger: Trigger or trigger spec. When ``None``,
            ``extension.trigger`` is used, falling back to
            ``(1, 'iteration')``. The value is normalized with
            ``trigger_module.get_trigger``.
        priority (int): When ``None``, ``extension.priority`` is used,
            falling back to ``extension_module.PRIORITY_READER``.
        **kwargs: Checked with ``argument.check_unexpected_kwargs``, which
            carries a migration message for ``invoke_before_training``,
            and then with ``argument.assert_kwargs_empty``.

    Raises:
        TypeError: If no name can be determined.
        ValueError: If the name is ``'training'``.

    A name that is already registered gets ``_1``, ``_2``, ... appended
    until it is unique.
    """
    argument.check_unexpected_kwargs(
        kwargs,
        invoke_before_training='invoke_before_training has been removed '
        'since Chainer v2.0.0. Use initializer= instead.')
    argument.assert_kwargs_empty(kwargs)

    if name is None:
        for attr in ('name', 'default_name', '__name__'):
            name = getattr(extension, attr, None)
            if name is not None:
                break
        else:
            raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')

    if trigger is None:
        trigger = getattr(extension, 'trigger', (1, 'iteration'))
    trigger = trigger_module.get_trigger(trigger)

    if priority is None:
        priority = getattr(
            extension, 'priority', extension_module.PRIORITY_READER)

    unique_name = name
    suffix = 0
    while unique_name in self._extensions:
        suffix += 1
        unique_name = '%s_%d' % (name, suffix)

    extension.name = unique_name
    self._extensions[unique_name] = _ExtensionEntry(
        extension, priority, trigger)
def extend(self, extension, name=None, trigger=None, priority=None,
           **kwargs):
    """Registers an extension to the trainer.

    Registered extensions are invoked in descending order of priority;
    ties keep registration order. If a name is already registered, the
    new entry receives a ``_N`` suffix, where N counts up from 1 until
    the name is unique.

    Args:
        extension: Extension to register. Its ``name`` attribute is set to
            the final registration name.
        name (str): Registration name. When ``None``, the first
            non-``None`` value among ``extension.name``,
            ``extension.default_name`` and ``extension.__name__`` is used.
        trigger (tuple or Trigger): When ``None``, ``extension.trigger``
            is used, or ``(1, 'iteration')`` if the extension has no such
            attribute. The value is normalized with
            ``trigger_module.get_trigger``.
        priority (int): When ``None``, ``extension.priority`` is used,
            falling back to ``extension_module.PRIORITY_READER``.
        **kwargs: Validated only when non-empty. The removed
            ``invoke_before_training`` option is reported with a migration
            message.

    Raises:
        TypeError: If no name can be determined.
        ValueError: If the name is ``'training'``.
    """
    if kwargs:
        argument.check_unexpected_kwargs(
            kwargs,
            invoke_before_training='invoke_before_training has been '
            'removed since Chainer v2.0.0. Use initializer= instead.')
        argument.assert_kwargs_empty(kwargs)

    if name is None:
        for attr in ('name', 'default_name', '__name__'):
            name = getattr(extension, attr, None)
            if name is not None:
                break
        else:
            raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')

    if trigger is None:
        trigger = getattr(extension, 'trigger', (1, 'iteration'))
    trigger = trigger_module.get_trigger(trigger)

    if priority is None:
        priority = getattr(extension, 'priority',
                           extension_module.PRIORITY_READER)

    candidate = name
    ordinal = 0
    while candidate in self._extensions:
        ordinal += 1
        candidate = '%s_%d' % (name, ordinal)

    extension.name = candidate
    self._extensions[candidate] = _ExtensionEntry(
        extension, priority, trigger)
def __init__(self, key, compare, trigger=(1, 'epoch')):
    """Track the best value seen under ``key``.

    Args:
        key: Reported key to monitor.
        compare: Comparison callable, stored as ``_compare``.
        trigger: Interval trigger spec passed to
            ``trigger_module.get_trigger``.
    """
    self._key = key
    self._best_value = None  # no value observed yet
    self._interval_trigger = trigger_module.get_trigger(trigger)
    self._init_summary()
    self._compare = compare
def __init__(self, trigger=(1, 'epoch'), image_generator=None):
    """Configure collection of generated images.

    Args:
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
        image_generator: Optional callable, stored as ``_fn``.
    """
    _check_available()
    self._trigger = trigger_module.get_trigger(trigger)
    self._fn = image_generator
    self._info_name, self._infos = '.chainerui_images', []
def __init__(self, experiment, model, function, trigger=(1, 'epoch')):
    """Bind an experiment, a model and a function to this extension.

    Args:
        experiment: Experiment object, stored as-is.
        model: Model object, stored as-is.
        function: Callable, stored as-is.
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
    """
    self._trigger = trigger_module.get_trigger(trigger)
    self.experiment, self.model, self.function = experiment, model, function
def extend(self, extension, name=None, trigger=None, priority=None,
           invoke_before_training=None):
    """Registers an extension to the trainer.

    Registered extensions are invoked in descending order of priority;
    ties keep registration order. A name that is already registered gets
    a ``_N`` suffix, with N counting up from 1 until the name is unique.

    Args:
        extension: Extension to register. Its ``name`` attribute is set to
            the final registration name.
        name (str): Registration name. When ``None``, ``extension.name``
            is used, then ``extension.default_name``.
        trigger (tuple or Trigger): When ``None``, ``extension.trigger``
            is used, or ``None`` if the extension has no such attribute.
            The value is then passed to ``trigger_module.get_trigger``,
            which decides what a ``None`` spec means.
        priority (int): When ``None``, ``extension.priority`` is used,
            falling back to ``extension_module.PRIORITY_READER``.
        invoke_before_training (bool or None): When ``None``,
            ``extension.invoke_before_training`` is used, falling back to
            ``False``. The value is stored in the registry entry.

    Raises:
        TypeError: If no name can be determined.
        ValueError: If the name is ``'training'``.
    """
    if name is None:
        for attr in ('name', 'default_name'):
            name = getattr(extension, attr, None)
            if name is not None:
                break
        else:
            raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')

    if trigger is None:
        trigger = getattr(extension, 'trigger', None)
    trigger = trigger_module.get_trigger(trigger)

    if priority is None:
        priority = getattr(
            extension, 'priority', extension_module.PRIORITY_READER)

    if invoke_before_training is None:
        invoke_before_training = getattr(
            extension, 'invoke_before_training', False)

    unique_name = name
    counter = 0
    while unique_name in self._extensions:
        counter += 1
        unique_name = '%s_%d' % (name, counter)

    extension.name = unique_name
    self._extensions[unique_name] = _ExtensionEntry(
        extension, priority, trigger, invoke_before_training)
def pre_run(self, data, max_epochs=1):
    """Record ``max_epochs`` and an epoch stop trigger, then run the engine.

    This method intercepts the engine run so that ``max_epochs`` is
    available afterwards. ``Engine`` itself does not keep it.

    Args:
        data: Data passed through to ``Engine.run``.
        max_epochs: Number of epochs. It is also used to build
            ``self.stop_trigger`` as ``(max_epochs, 'epoch')``.

    Returns:
        Whatever ``Engine.run`` returns. Forwarding it means callers that
        go through this interception still receive the run result.
    """
    self.max_epochs = max_epochs
    self.stop_trigger = trigger_module.get_trigger((max_epochs, 'epoch'))
    return Engine.run(self.engine, data, max_epochs)
def __init__(self, shift, attr='lr', trigger=(1, 'epoch')):
    """Configure a shift of the attribute named ``attr``.

    Args:
        shift: Shift value, stored as-is.
        attr: Attribute name to adjust (default ``'lr'``).
        trigger: Trigger spec passed to ``trigger_module.get_trigger``.
    """
    self.trigger = trigger_module.get_trigger(trigger)
    self.shift, self.attr = shift, attr
def extend(self, extension, name=None, trigger=None, priority=None,
           invoke_before_training=None):
    """Registers an extension to the trainer.

    Registered extensions are invoked in descending order of priority;
    ties keep registration order. A name that is already registered gets
    a ``_N`` suffix, with N counting up from 1 until the name is unique.

    Args:
        extension: Extension to register. Its ``name`` attribute is set to
            the final registration name.
        name (str): Registration name. When ``None``, the first
            non-``None`` value among ``extension.name``,
            ``extension.default_name`` and ``extension.__name__`` is used.
        trigger (tuple or Trigger): When ``None``, ``extension.trigger``
            is used, or ``(1, 'iteration')`` if the extension has no such
            attribute. The value is normalized with
            ``trigger_module.get_trigger``.
        priority (int): When ``None``, ``extension.priority`` is used,
            falling back to ``extension_module.PRIORITY_READER``.
        invoke_before_training (bool or None): When ``None``,
            ``extension.invoke_before_training`` is used, falling back to
            ``False``. The value is stored in the registry entry.

    Raises:
        TypeError: If no name can be determined.
        ValueError: If the name is ``'training'``.
    """
    if name is None:
        for attr in ('name', 'default_name', '__name__'):
            name = getattr(extension, attr, None)
            if name is not None:
                break
        else:
            raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')

    if trigger is None:
        trigger = getattr(extension, 'trigger', (1, 'iteration'))
    trigger = trigger_module.get_trigger(trigger)

    if priority is None:
        priority = getattr(extension, 'priority',
                           extension_module.PRIORITY_READER)

    if invoke_before_training is None:
        invoke_before_training = getattr(extension,
                                         'invoke_before_training', False)

    registered = name
    n = 0
    while registered in self._extensions:
        n += 1
        registered = '%s_%d' % (name, n)

    extension.name = registered
    self._extensions[registered] = _ExtensionEntry(
        extension, priority, trigger, invoke_before_training)
def __init__(self, key, compare, trigger=(1, 'epoch')):
    """Set up best-value tracking for the reported ``key``.

    Args:
        key: Reported key to monitor.
        compare: Comparison callable, stored as ``_compare``.
        trigger: Interval trigger spec, normalized with
            ``trigger_module.get_trigger``.
    """
    self._key, self._best_value = key, None
    self._interval_trigger = trigger_module.get_trigger(trigger)
    self._init_summary()
    self._compare = compare
def extend(self, extension, name=None, trigger=None, priority=None,
           **kwargs):
    """Registers an extension to the trainer.

    Registered extensions are invoked in descending order of priority;
    ties keep registration order. A name that is already registered gets
    a ``_N`` suffix, with N counting up from 1 until the name is unique.

    Args:
        extension: Extension to register. Its ``name`` attribute is set to
            the final registration name.
        name (str): Registration name. When ``None``, the first
            non-``None`` value among ``extension.name``,
            ``extension.default_name`` and ``extension.__name__`` is used.
        trigger (tuple or Trigger): When ``None``, ``extension.trigger``
            is used, or ``(1, 'iteration')`` if the extension has no such
            attribute. The value is normalized with
            ``trigger_module.get_trigger``.
        priority (int): When ``None``, ``extension.priority`` is used,
            falling back to ``extension_module.PRIORITY_READER``.
        **kwargs: Always validated. The removed ``invoke_before_training``
            option is reported with a migration message.

    Raises:
        TypeError: If no name can be determined.
        ValueError: If the name is ``'training'``.
    """
    argument.check_unexpected_kwargs(
        kwargs,
        invoke_before_training='invoke_before_training has been removed '
        'since Chainer v2.0.0. Use initializer= instead.')
    argument.assert_kwargs_empty(kwargs)

    if name is None:
        for source_attr in ('name', 'default_name', '__name__'):
            name = getattr(extension, source_attr, None)
            if name is not None:
                break
        else:
            raise TypeError('name is not given for the extension')
    if name == 'training':
        raise ValueError(
            'the name "training" is prohibited as an extension name')

    if trigger is None:
        trigger = getattr(extension, 'trigger', (1, 'iteration'))
    trigger = trigger_module.get_trigger(trigger)

    if priority is None:
        priority = getattr(
            extension, 'priority', extension_module.PRIORITY_READER)

    final_name = name
    index = 0
    while final_name in self._extensions:
        index += 1
        final_name = '%s_%d' % (name, index)

    extension.name = final_name
    self._extensions[final_name] = _ExtensionEntry(
        extension, priority, trigger)