예제 #1
0
 def set_triggers(self):
     """Rebuild trigger-related attributes from ``self.current_state``.

     Reads ``target_lr``, ``update_trigger`` and ``stop_trigger`` from the
     state mapping. The ``stop_trigger`` spec is converted to a trigger object
     and is also unpacked into ``phase_length`` and ``unit``, so it must be a
     two-element sequence.
     """
     state = self.current_state
     self.target_lr = state['target_lr']
     self.update_trigger = trigger_module.get_trigger(state['update_trigger'])
     stop_spec = state['stop_trigger']
     self.stop_trigger = trigger_module.get_trigger(stop_spec)
     self.phase_length, self.unit = stop_spec
예제 #2
0
    def __init__(self,
                 y_keys,
                 x_key='iteration',
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 filename=None,
                 marker='x',
                 grid=True,
                 **kwargs):
        """Configure a report that plots values of ``y_keys`` over ``x_key``.

        Args:
            y_keys: A key or a sequence of keys to plot. A single ``str`` is
                wrapped into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            filename: Output file name. When ``None``, the legacy
                ``file_name`` keyword (default ``'plot.png'``) is used.
            marker: Marker style, stored as ``self._marker``.
            grid: Flag stored as ``self._grid``.
            **kwargs: Parsed by ``argument.parse_kwargs``; only the legacy
                ``file_name`` keyword is recognised.
        """
        legacy_name, = argument.parse_kwargs(
            kwargs, ('file_name', 'plot.png'))
        output_name = legacy_name if filename is None else filename

        _check_available()

        keys = (y_keys, ) if isinstance(y_keys, str) else y_keys

        self._x_key = x_key
        self._y_keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = output_name
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {key: [] for key in keys}
예제 #3
0
파일: plot_report.py 프로젝트: yygr/chainer
    def __init__(self,
                 y_keys,
                 x_key='iteration',
                 trigger=(1, 'epoch'),
                 postprocess=None,
                 file_name='plot.png',
                 marker='x',
                 grid=True):
        """Configure a report that plots values of ``y_keys`` over ``x_key``.

        If the module-level ``_available`` flag is false after
        ``_check_available()`` runs, the constructor returns immediately and
        none of the attributes below are set.

        Args:
            y_keys: A key or a sequence of keys to plot. A single ``str`` is
                wrapped into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            file_name: Output file name.
            marker: Marker style, stored as ``self._marker``.
            grid: Flag stored as ``self._grid``.
        """
        _check_available()
        if not _available:
            return

        keys = (y_keys, ) if isinstance(y_keys, str) else y_keys

        self._x_key = x_key
        self._y_keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {key: [] for key in keys}
예제 #4
0
파일: trainer.py 프로젝트: tkng/chainer
    def __init__(self, updater, stop_trigger=None, out='result',
                 extensions=None):
        """Create a trainer driven by ``updater``.

        Args:
            updater: Object providing ``get_all_optimizers()`` and
                ``connect_trainer(trainer)``.
            stop_trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            out: Output directory name, stored as ``self.out``.
            extensions: Optional iterable of extensions; each one is
                registered through ``self.extend`` after the updater has been
                connected.
        """
        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        # Observe every optimizer's target link together with all of its
        # named sub-links under the optimizer's name.
        observer = reporter_module.Reporter()
        for opt_name, opt in six.iteritems(updater.get_all_optimizers()):
            model = opt.target
            observer.add_observer(opt_name, model)
            observer.add_observers(opt_name, model.namedlinks(skipself=True))
        self.reporter = observer

        self._done = False
        self._extensions = collections.OrderedDict()

        # Timing bookkeeping.
        self._start_at = None
        self._snapshot_elapsed_time = 0.0
        self._final_elapsed_time = None

        updater.connect_trainer(self)
        for ext in (extensions if extensions is not None else []):
            self.extend(ext)
예제 #5
0
    def __init__(self,
                 dataset_specification,
                 dataset_class,
                 blank_label,
                 trigger=(1, 'epoch'),
                 min_delta=0.01,
                 attributes_to_adjust=(),
                 maxlen=5,
                 dataset_args=None):
        """Load a dataset curriculum from a JSON specification file.

        Args:
            dataset_specification: Path to a JSON file holding an iterable of
                entries (e.g. a list). Each entry must contain ``'train'`` and
                ``'validation'`` keys. The entry's position becomes its level
                index.
            dataset_class: Dataset class, stored as ``self.dataset_class``.
            blank_label: Stored as ``self.blank_label``.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            min_delta: Threshold stored as ``self.min_delta``.
            attributes_to_adjust: Stored as ``self.attributes_to_adjust``.
            maxlen: Capacity of the bounded ``self.queue`` deque.
            dataset_args: Optional dict of extra dataset arguments; defaults
                to a fresh empty dict.
        """
        self.dataset_class = dataset_class
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.trigger = trigger_module.get_trigger(trigger)
        self.maxlen = maxlen
        # Bounded history: only the last ``maxlen`` entries are retained.
        self.queue = deque(maxlen=self.maxlen)
        self.min_delta = min_delta
        self.attributes_to_adjust = attributes_to_adjust
        self.blank_label = blank_label
        self.force_enlarge_dataset = False

        with open(dataset_specification) as spec_file:
            levels = json.load(spec_file)

        # Build both curricula in a single pass, keyed by level index.
        self.train_curriculum = {}
        self.validation_curriculum = {}
        for level, entry in enumerate(levels):
            self.train_curriculum[level] = entry['train']
            self.validation_curriculum[level] = entry['validation']

        self.current_level = 0
예제 #6
0
    def __init__(self,
                 links,
                 statistics=default_statistics,
                 report_params=True,
                 report_grads=True,
                 prefix=None,
                 trigger=(1, 'epoch'),
                 skip_nan_params=False):
        """Configure statistics reporting for the parameters of ``links``.

        Args:
            links: A link, or a list/tuple of links. Anything that is not a
                list or tuple is wrapped into a one-element tuple.
            statistics: Mapping of statistics to compute; ``None`` is
                replaced by an empty dict.
            report_params: If true, ``'data'`` is added to the reported
                attributes.
            report_grads: If true, ``'grad'`` is added to the reported
                attributes.
            prefix: Optional prefix, stored as ``self._prefix``.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            skip_nan_params: Flag stored as ``self._skip_nan_params``.
        """
        self._links = links if isinstance(links, (list, tuple)) else (links,)
        self._statistics = {} if statistics is None else statistics
        self._attrs = [attr for attr, wanted in
                       (('data', report_params), ('grad', report_grads))
                       if wanted]
        self._prefix = prefix
        self._trigger = trigger_module.get_trigger(trigger)
        self._summary = reporter.DictSummary()
        self._skip_nan_params = skip_nan_params
예제 #7
0
    def __init__(self,
                 targets,
                 max_sample_size=1000,
                 report_data=True,
                 report_grad=True,
                 plot_mean=True,
                 plot_std=True,
                 percentile_sigmas=(0, 0.13, 2.28, 15.87, 50, 84.13, 97.72,
                                    99.87, 100),
                 trigger=(1, 'epoch'),
                 file_name='statistics.png',
                 figsize=None,
                 marker=None,
                 grid=True):
        """Configure a plot of statistics gathered from ``targets``.

        Args:
            targets: Objects passed to ``_unpack_variables``; at least one
                variable must result.
            max_sample_size: Capacity passed to ``Reservoir``.
            report_data: If true, ``'data'`` is added to the sampled keys.
            report_grad: If true, ``'grad'`` is added to the sampled keys.
            plot_mean: Whether the mean is collected and plotted.
            plot_std: Whether the standard deviation is collected and plotted.
            percentile_sigmas: A list/tuple of percentiles, or a single scalar
                (counted as one percentile). A falsy value disables
                percentiles.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            file_name: Output file name; must not be ``None``.
            figsize: Stored as ``self._figsize``.
            marker: Stored as ``self._marker``.
            grid: Stored as ``self._grid``.

        Raises:
            ValueError: If ``file_name`` is ``None``, if no variables are
                found, or if mean, std and percentiles are all disabled.
        """
        if file_name is None:
            raise ValueError('Missing output file name of statstics plot')

        self._vars = _unpack_variables(targets)
        if not len(self._vars):
            raise ValueError(
                'Need at least one variables for which to collect statistics.'
                '\nActual: 0 <= 0')

        has_percentiles = bool(percentile_sigmas)
        if not (plot_mean or plot_std or has_percentiles):
            raise ValueError('Nothing to plot')

        self._keys = [key for key, wanted in
                      (('data', report_data), ('grad', report_grad))
                      if wanted]
        self._report_data = report_data
        self._report_grad = report_grad

        self._statistician = Statistician(
            collect_mean=plot_mean, collect_std=plot_std,
            percentile_sigmas=percentile_sigmas)

        self._plot_mean = plot_mean
        self._plot_std = plot_std
        self._plot_percentile = has_percentiles

        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._figsize = figsize
        self._marker = marker
        self._grid = grid

        if not has_percentiles:
            n_percentile = 0
        elif isinstance(percentile_sigmas, (list, tuple)):
            n_percentile = len(percentile_sigmas)
        else:
            n_percentile = 1  # a scalar stands for a single percentile
        # One row per sampled key, one column per plotted statistic.
        n_series = int(plot_mean) + int(plot_std) + n_percentile
        self._data_shape = (len(self._keys), n_series)
        self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
예제 #8
0
    def __init__(self, trigger=(1, 'iteration'), receivers=None,
                 file_name='commands'):
        """Configure command polling.

        Args:
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            receivers: Optional mapping merged over a copy of
                ``self.default_receivers``. Entries here override defaults
                that have the same key. The default is ``None`` rather than
                a shared mutable ``{}``.
            file_name: Stored as ``self._file_name``.
        """
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name

        # Copy so that per-instance overrides never mutate the defaults.
        self._receivers = self.default_receivers.copy()
        if receivers is not None:
            self._receivers.update(receivers)
예제 #9
0
    def __init__(self,
                 updater,
                 stop_trigger=None,
                 out='result',
                 extensions=None):
        """Create a trainer around ``updater``.

        Args:
            updater: Object providing ``get_all_optimizers()`` and
                ``connect_trainer(trainer)``.
            stop_trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            out: Output directory name, stored as ``self.out``.
            extensions: Optional iterable of extensions, each registered via
                ``self.extend`` once the updater is connected.
        """
        if extensions is None:
            extensions = []

        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        # Register each optimizer target, plus its named sub-links, as
        # observers keyed by the optimizer name.
        rep = reporter_module.Reporter()
        optimizers = updater.get_all_optimizers()
        for key, optimizer in six.iteritems(optimizers):
            target_link = optimizer.target
            rep.add_observer(key, target_link)
            rep.add_observers(key, target_link.namedlinks(skipself=True))
        self.reporter = rep

        self._done = False
        self._extensions = collections.OrderedDict()

        self._start_at = None
        self._snapshot_elapsed_time = 0.0
        self._final_elapsed_time = None

        updater.connect_trainer(self)
        for extension in extensions:
            self.extend(extension)
 def __init__(self, shift, attr='lr', trigger=(1, 'iteration'), min_delta=0.1):
     """Store the configuration for shifting the attribute named ``attr``.

     Args:
         shift: Stored as ``self.shift``.
         attr: Attribute name, ``'lr'`` by default.
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
         min_delta: Threshold stored as ``self.min_delta``.
     """
     self.shift, self.attr = shift, attr
     self.trigger = trigger_module.get_trigger(trigger)
     # Bounded history holding at most the five most recent entries.
     self.queue = deque(maxlen=5)
     self.min_delta = min_delta
     self.force_shift = False
예제 #11
0
    def __init__(self, targets, max_sample_size=1000,
                 report_data=True, report_grad=True,
                 plot_mean=True, plot_std=True,
                 percentile_sigmas=(
                     0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
                 trigger=(1, 'epoch'), filename='statistics.png',
                 figsize=None, marker=None, grid=True, **kwargs):
        """Configure a plot of statistics gathered from ``targets``.

        Args:
            targets: Objects passed to ``_unpack_variables``; at least one
                variable must result.
            max_sample_size: Capacity passed to ``Reservoir``.
            report_data: If true, ``'data'`` is added to the sampled keys.
            report_grad: If true, ``'grad'`` is added to the sampled keys.
            plot_mean: Whether the mean is collected and plotted.
            plot_std: Whether the standard deviation is collected and plotted.
            percentile_sigmas: A list/tuple of percentiles, or a single scalar
                (counted as one percentile). A falsy value disables
                percentiles.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            filename: Output file name. When ``None``, the legacy
                ``file_name`` keyword (default ``'statistics.png'``) is used.
            figsize: Stored as ``self._figsize``.
            marker: Stored as ``self._marker``.
            grid: Stored as ``self._grid``.
            **kwargs: Parsed by ``argument.parse_kwargs``; only the legacy
                ``file_name`` keyword is recognised.

        Raises:
            ValueError: If no output name is available, if no variables are
                found, or if mean, std and percentiles are all disabled.
        """
        legacy_name, = argument.parse_kwargs(
            kwargs, ('file_name', 'statistics.png'))
        output_name = legacy_name if filename is None else filename
        if output_name is None:
            raise ValueError('Missing output file name of statstics plot')

        self._vars = _unpack_variables(targets)
        if not len(self._vars):
            raise ValueError(
                'Need at least one variables for which to collect statistics.'
                '\nActual: 0 <= 0')

        has_percentiles = bool(percentile_sigmas)
        if not (plot_mean or plot_std or has_percentiles):
            raise ValueError('Nothing to plot')

        self._keys = [key for key, wanted in
                      (('data', report_data), ('grad', report_grad))
                      if wanted]
        self._report_data = report_data
        self._report_grad = report_grad

        self._statistician = Statistician(
            collect_mean=plot_mean, collect_std=plot_std,
            percentile_sigmas=percentile_sigmas)

        self._plot_mean = plot_mean
        self._plot_std = plot_std
        self._plot_percentile = has_percentiles

        self._trigger = trigger_module.get_trigger(trigger)
        self._filename = output_name
        self._figsize = figsize
        self._marker = marker
        self._grid = grid

        if not has_percentiles:
            n_percentile = 0
        elif isinstance(percentile_sigmas, (list, tuple)):
            n_percentile = len(percentile_sigmas)
        else:
            n_percentile = 1  # a scalar stands for a single percentile
        # One row per sampled key, one column per plotted statistic.
        n_series = int(plot_mean) + int(plot_std) + n_percentile
        self._data_shape = (len(self._keys), n_series)
        self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
예제 #12
0
파일: log_report.py 프로젝트: RE-ID/chainer
    def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
                 log_name='log'):
        """Configure a log report.

        Args:
            keys: Keys to record, stored as ``self._keys`` (may be ``None``).
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            log_name: Log output name, stored as ``self._log_name``.
        """
        self._log = []
        self._keys = keys
        self._postprocess = postprocess
        self._log_name = log_name
        self._trigger = trigger_module.get_trigger(trigger)

        self._init_summary()
예제 #13
0
 def __init__(self, detector, image_paths, names, size, thresh=0.6,
              trigger=(1, 'epoch'), device=-1):
     """Store detector and image settings for this extension.

     Args:
         detector: Stored as ``self._detector``.
         image_paths: Stored as ``self._image_paths``.
         names: Stored as ``self._names``.
         size: Stored as ``self._size``.
         thresh: Threshold stored as ``self._thresh`` (default 0.6).
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
         device: Device id stored as ``self._device``. The default -1
             presumably means CPU; confirm against the caller.
     """
     self._detector, self._image_paths = detector, image_paths
     self._names, self._size = names, size
     self._thresh = thresh
     self._trigger = trigger_module.get_trigger(trigger)
     self._device = device
예제 #14
0
    def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
                 log_name='log'):
        """Configure a log report.

        Args:
            keys: Keys to record, stored as ``self._keys`` (may be ``None``).
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            log_name: Log output name, stored as ``self._log_name``.
        """
        self._keys, self._postprocess = keys, postprocess
        self._trigger = trigger_module.get_trigger(trigger)
        self._log_name = log_name
        self._log = []

        self._init_summary()
예제 #15
0
    def __init__(self, notifier, custom_metrics=None):
        """Bind a notifier and optional custom metrics.

        Args:
            notifier: Required notifier object, stored as ``self.notifier``.
            custom_metrics: Optional metrics. ``self.custom_metrics`` is only
                set when this is not ``None``.

        Raises:
            ValueError: If ``notifier`` is ``None``.
        """
        # TODO: additional metrics calculation inside training loop
        if custom_metrics is not None:
            self.custom_metrics = custom_metrics

        if notifier is None:
            raise ValueError('Notifier is None')
        self.notifier = notifier

        # Notifier-specific attributes.
        self.details = None  # dict of available details: epochs, batch_size, lr, etc
        self.starting_time = None
        self.current_epoch = 1

        # Chainer-specific triggers built from (1, 'epoch') and
        # (1, 'iteration') interval specs.
        per_epoch, per_iteration = (1, 'epoch'), (1, 'iteration')
        self._trigger_epoch = trigger_module.get_trigger(per_epoch)
        self._trigger_iteration = trigger_module.get_trigger(per_iteration)
예제 #16
0
 def __init__(self,
              dataset,
              crop_sizes,
              max_size_iteration=None,
              trigger=(80, 'iteration')):
     """Configure crop-size updates for ``dataset``.

     Args:
         dataset: Stored as ``self.dataset``.
         crop_sizes: Stored as ``self.crop_sizes``.
         max_size_iteration: Optional value stored as
             ``self.max_size_iteration``.
         trigger: Trigger spec converted with ``trigger_module.get_trigger``;
             defaults to ``(80, 'iteration')``.
     """
     super(CropSizeUpdater, self).__init__()
     self.dataset, self.crop_sizes = dataset, crop_sizes
     self.max_size_iteration = max_size_iteration
     self._trigger = trigger_module.get_trigger(trigger)
예제 #17
0
    def __init__(self, keys=None, trigger=(1, 'iteration'),
                 log_json_name='log', log_csv_name='log.csv'):
        """Configure a report that logs to both JSON and CSV outputs.

        Args:
            keys: Keys to record, stored as ``self._keys`` (may be ``None``).
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            log_json_name: JSON log name, stored as ``self._log_json_name``.
            log_csv_name: CSV log name, stored as ``self._log_csv_name``.
        """
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._log_csv_name, self._log_json_name = log_csv_name, log_json_name
        self._log = []

        self._init_summary()

        # Starts False; presumably flipped once outputs are first set up
        # (confirm against the rest of the class).
        self._is_initialized = False
예제 #18
0
    def __init__(self, watch_items, trigger=(100, 'iteration'), interval=5.0):
        """Configure heartbeat watching.

        Args:
            watch_items: Iterable of ``WatchItem`` objects or
                ``(action, estimator)`` pairs. Pairs are converted to
                ``WatchItem``.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            interval: Stored as ``self._interval`` (default 5.0).
        """
        self._trigger = trigger_module.get_trigger(trigger)
        self._interval = interval

        def _as_watch_item(entry):
            # Pass WatchItem through; unpack anything else as a 2-tuple.
            if isinstance(entry, WatchItem):
                return entry
            action, estimator = entry
            return WatchItem(action=action, estimator=estimator)

        self._watch_items = [_as_watch_item(entry) for entry in watch_items]

        self._heartbeat_queue = Queue()
        self._stop_event = Event()
        # NOTE: despite its name, this attribute holds a ``Process``.
        worker = Process(target=self._heartbeat_handler)
        worker.daemon = True
        self._heartbeat_thread = worker
예제 #19
0
    def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
                 postprocess=None, file_name='graph.png'):
        """Configure a report that graphs ``y_keys`` against ``x_key``.

        Args:
            y_keys: A key or a sequence of keys; a single ``str`` is wrapped
                into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            file_name: Output file name (default ``'graph.png'``).
        """
        _check_available()

        keys = (y_keys,) if isinstance(y_keys, str) else y_keys

        self._x_key = x_key
        self._y_keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._postprocess = postprocess
        self._init_summary()
        self._data = {key: [] for key in keys}
예제 #20
0
    def __init__(self, keys=None, trigger=(1, 'epoch'), postprocess=None,
                 filename=None, **kwargs):
        """Configure a log report.

        Args:
            keys: Keys to record, stored as ``self._keys`` (may be ``None``).
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            filename: Log output name. When ``None``, the legacy ``log_name``
                keyword (default ``'log'``) is used.
            **kwargs: Parsed by ``argument.parse_kwargs``; only the legacy
                ``log_name`` keyword is recognised.
        """
        self._keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess = postprocess
        self._log = []

        legacy_name, = argument.parse_kwargs(kwargs, ('log_name', 'log'))
        self._log_name = legacy_name if filename is None else filename

        self._init_summary()
 def __init__(self,
              target,
              keys,
              s_keys=None,
              mname=None,
              log_report='LogReport',
              trigger=(1, 'epoch')):
     """Configure the targets and keys handled by this extension.

     Args:
         target: A ``link.Link`` (wrapped as ``{'main': target}``) or an
             already-built mapping of targets.
         keys: Stored as ``self._keys``.
         s_keys: Optional list of extra keys. When omitted, each instance
             gets its own empty list; a shared ``[]`` default would be
             aliased across every instance.
         mname: Optional value stored as ``self._mname``.
         log_report: Stored as ``self._log_report`` (default
             ``'LogReport'``).
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
     """
     if isinstance(target, link.Link):
         target = {'main': target}
     self._targets = target
     self._keys = keys
     self._s_keys = [] if s_keys is None else s_keys
     self._mname = mname
     self._log_report = log_report
     self.trigger = trigger_module.get_trigger(trigger)
예제 #22
0
    def __init__(
            self,
            writer,
            keys=None,
            trigger=(1, "epoch"),
            postprocess=None,
            log_name="log.json",
    ):
        """Configure a log report that uses ``writer``.

        Args:
            writer: Stored as ``self._writer``.
            keys: Keys to record, stored as ``self._keys`` (may be ``None``).
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            log_name: Log output name (default ``"log.json"``).
        """
        self._writer, self._keys = writer, keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._postprocess, self._log_name = postprocess, log_name
        self._log = []

        self._init_summary()
예제 #23
0
파일: trainer.py 프로젝트: mattya/chainer
    def __init__(self, updater, stop_trigger=None, out='result'):
        """Create a trainer around ``updater``.

        Args:
            updater: Object providing ``get_all_optimizers()`` and
                ``connect_trainer(trainer)``.
            stop_trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            out: Output directory name, stored as ``self.out``.
        """
        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        # Each optimizer's target and its named sub-links become observers
        # under the optimizer's name.
        rep = reporter_module.Reporter()
        for opt_name, opt in six.iteritems(updater.get_all_optimizers()):
            model = opt.target
            rep.add_observer(opt_name, model)
            rep.add_observers(opt_name, model.namedlinks(skipself=True))
        self.reporter = rep

        self._done = False
        self._extensions = collections.OrderedDict()

        updater.connect_trainer(self)
예제 #24
0
파일: trainer.py 프로젝트: asrlabncku/RAP
    def __init__(self, updater, stop_trigger=None, out='result'):
        """Create a trainer around ``updater``.

        Args:
            updater: Object providing ``get_all_optimizers()`` and
                ``connect_trainer(trainer)``.
            stop_trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            out: Output directory name, stored as ``self.out``.
        """
        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        observer = reporter_module.Reporter()
        optimizers = updater.get_all_optimizers()
        for key, optimizer in six.iteritems(optimizers):
            target_link = optimizer.target
            # Register the link itself and every named descendant link.
            observer.add_observer(key, target_link)
            observer.add_observers(key, target_link.namedlinks(skipself=True))
        self.reporter = observer

        self._done = False
        self._extensions = collections.OrderedDict()

        updater.connect_trainer(self)
    def __init__(self,
                 y_keys,
                 x_key='iteration',
                 trigger=(1, 'epoch'),
                 log_dir=None):
        """Configure a report that writes ``y_keys`` through a SummaryWriter.

        Args:
            y_keys: A key or a sequence of keys; a single ``str`` is wrapped
                into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            log_dir: Directory for the writer. It is created if missing.
                ``None`` (the default) is passed to ``SummaryWriter``
                unchanged; previously this crashed in ``os.path.isdir(None)``.
                Presumably the writer then picks its own default directory;
                confirm for the SummaryWriter implementation in use.
        """
        if log_dir is not None:
            try:
                os.makedirs(log_dir)
            except OSError:
                # Tolerate an existing directory, including one created
                # concurrently; re-raise any other failure.
                if not os.path.isdir(log_dir):
                    raise
        self._log_dir = log_dir

        self._writer = SummaryWriter(log_dir)

        self._x_key = x_key
        if isinstance(y_keys, str):
            y_keys = (y_keys, )

        self._y_keys = y_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._init_summary()
        self._data = {k: [] for k in y_keys}
예제 #26
0
    def __init__(self, links, statistics=default_statistics,
                 report_params=True, report_grads=True, prefix=None,
                 trigger=(1, 'epoch')):
        """Configure statistics reporting for the parameters of ``links``.

        Args:
            links: A link, or a list/tuple of links. Anything else is wrapped
                into a one-element tuple.
            statistics: Mapping of statistics, stored as-is.
            report_params: If true, ``'data'`` is added to the reported
                attributes.
            report_grads: If true, ``'grad'`` is added to the reported
                attributes.
            prefix: Optional prefix, stored as ``self._prefix``.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
        """
        if isinstance(links, (list, tuple)):
            self._links = links
        else:
            self._links = (links,)

        self._statistics = statistics

        selected = []
        for attr, wanted in (('data', report_params), ('grad', report_grads)):
            if wanted:
                selected.append(attr)
        self._attrs = selected

        self._prefix = prefix
        self._trigger = trigger_module.get_trigger(trigger)
        self._summary = reporter.DictSummary()
예제 #27
0
    def __init__(self, trigger=(10000, 'iteration'), postprocess=None,
                 segmentation_loss_key='main/loss/mask',
                 detection_loss_loc_key='main/loss/loc',
                 detection_loss_conf_key='main/loss/conf', smooth_alpha=0.85,
                 split_alpha=0.15):
        """Configure the loss-division extension.

        Args:
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            segmentation_loss_key: Observation key of the segmentation loss.
            detection_loss_loc_key: Observation key of the localisation loss.
            detection_loss_conf_key: Observation key of the confidence loss.
            smooth_alpha: Stored as ``self.alpha``.
            split_alpha: Stored as ``self.split_alpha``.
        """
        # conduct the action of loss division
        self._trigger = trigger_module.get_trigger(trigger)
        self.alpha = smooth_alpha
        self.split_alpha = split_alpha
        self._postprocess = postprocess
        self._segmentation_loss_key = segmentation_loss_key
        self._detection_loss_conf_key = detection_loss_conf_key
        self._detection_loss_loc_key = detection_loss_loc_key

        # Per-loss "max" / "current" slots, judging by the names; all start
        # unset.
        self._max_loss_seg = self._current_loss_seg = None
        self._max_loss_det_loc = self._current_loss_det_loc = None
        self._max_loss_det_conf = self._current_loss_det_conf = None
        self._current_loss_split = None
예제 #28
0
    def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
                 postprocess=None, file_name='plot.png', marker='x',
                 grid=True):
        """Configure a report that plots ``y_keys`` against ``x_key``.

        When ``_available`` is false after ``_check_available()``, the
        constructor returns early without setting any attributes.

        Args:
            y_keys: A key or a sequence of keys; a single ``str`` is wrapped
                into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            file_name: Output file name.
            marker: Stored as ``self._marker``.
            grid: Stored as ``self._grid``.
        """
        _check_available()
        if not _available:
            return

        keys = (y_keys,) if isinstance(y_keys, str) else y_keys

        self._x_key = x_key
        self._y_keys = keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._marker, self._grid = marker, grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {key: [] for key in keys}
예제 #29
0
파일: plot_report.py 프로젝트: hvy/chainer
    def __init__(self, y_keys, x_key='iteration', trigger=(1, 'epoch'),
                 postprocess=None, filename=None, marker='x',
                 grid=True, **kwargs):
        """Configure a report that plots ``y_keys`` against ``x_key``.

        Args:
            y_keys: A key or a sequence of keys; a single ``str`` is wrapped
                into a one-element tuple.
            x_key: Key used for the x axis values.
            trigger: Trigger spec converted with
                ``trigger_module.get_trigger``.
            postprocess: Optional callable, stored as ``self._postprocess``.
            filename: Output file name. When ``None``, the legacy
                ``file_name`` keyword (default ``'plot.png'``) is used.
            marker: Stored as ``self._marker``.
            grid: Stored as ``self._grid``.
            **kwargs: Parsed by ``argument.parse_kwargs``; only the legacy
                ``file_name`` keyword is recognised.
        """
        file_name, = argument.parse_kwargs(kwargs, ('file_name', 'plot.png'))
        if filename is None:
            filename = file_name
        # Only the resolved ``filename`` may be used from here on. Storing
        # ``file_name`` would silently ignore an explicit ``filename``.
        del file_name

        _check_available()

        self._x_key = x_key
        if isinstance(y_keys, str):
            y_keys = (y_keys,)

        self._y_keys = y_keys
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = filename
        self._marker = marker
        self._grid = grid
        self._postprocess = postprocess
        self._init_summary()
        self._data = {k: [] for k in y_keys}
    def extend(self, extension, name=None, trigger=None, priority=None,
               **kwargs):
        """Register ``extension`` on this trainer under a unique name.

        The name falls back to the extension's ``name``, then
        ``default_name``, then ``__name__`` attribute. If the name is already
        taken, ``_1``, ``_2``, ... is appended. The trigger falls back to
        ``extension.trigger`` or ``(1, 'iteration')``. The priority falls back
        to ``extension.priority`` or ``extension_module.PRIORITY_READER``.

        Raises:
            TypeError: If no name can be determined.
            ValueError: If the resolved name is ``'training'``.
        """
        # The removed ``invoke_before_training`` keyword gets a dedicated
        # message; any other leftover keyword is rejected.
        argument.check_unexpected_kwargs(
            kwargs,
            invoke_before_training='invoke_before_training has been removed '
            'since Chainer v2.0.0. Use initializer= instead.')
        argument.assert_kwargs_empty(kwargs)

        if name is None:
            for attr_name in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr_name, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(
                extension, 'priority', extension_module.PRIORITY_READER)

        unique_name, suffix = name, 0
        while unique_name in self._extensions:
            suffix += 1
            unique_name = '%s_%d' % (name, suffix)

        extension.name = unique_name
        self._extensions[unique_name] = _ExtensionEntry(
            extension, priority, trigger)
예제 #31
0
    def extend(self,
               extension,
               name=None,
               trigger=None,
               priority=None,
               **kwargs):
        """Register an extension on the trainer.

        An extension is a callable invoked after each update unless its
        trigger decides to skip that iteration. Extensions with higher
        priority run earlier within an iteration. Ties run in registration
        order.

        If a name is registered more than once, the later registrations get
        a ``_N`` suffix, where ``N`` is the ordinal of the duplicate.

        See :class:`Extension` for the extension interface.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. When ``None``, the extension's
                ``name`` attribute is used, then ``default_name``, then
                ``__name__``. The name may receive an ordinal suffix as
                described above.
            trigger (tuple or Trigger): Decides when the extension is
                invoked. When ``None``, ``extension.trigger`` is used, or
                ``(1, 'iteration')`` if the extension has no such attribute.
                Non-callable values are turned into an interval trigger by
                ``trigger_module.get_trigger``.
            priority (int): Invocation priority; higher runs first. When
                ``None``, ``extension.priority`` is used, or
                ``extension_module.PRIORITY_READER`` if absent.

        Raises:
            TypeError: If no name can be determined.
            ValueError: If the resolved name is ``'training'``.
        """
        if kwargs:
            argument.check_unexpected_kwargs(
                kwargs,
                invoke_before_training='invoke_before_training has been '
                'removed since Chainer v2.0.0. Use initializer= instead.')
            argument.assert_kwargs_empty(kwargs)

        if name is None:
            for attr_name in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr_name, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        trigger = trigger_module.get_trigger(
            getattr(extension, 'trigger', (1, 'iteration'))
            if trigger is None else trigger)

        if priority is None:
            priority = getattr(extension, 'priority',
                               extension_module.PRIORITY_READER)

        registered_name, ordinal = name, 0
        while registered_name in self._extensions:
            ordinal += 1
            registered_name = '%s_%d' % (name, ordinal)

        extension.name = registered_name
        self._extensions[registered_name] = _ExtensionEntry(
            extension, priority, trigger)
예제 #32
0
 def __init__(self, key, compare, trigger=(1, 'epoch')):
     """Track the best value observed under ``key``.

     Args:
         key: Observation key, stored as ``self._key``.
         compare: Comparison callable, stored as ``self._compare``.
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
     """
     self._key, self._best_value = key, None  # no best value at start
     self._interval_trigger = trigger_module.get_trigger(trigger)
     self._init_summary()
     self._compare = compare
예제 #33
0
 def __init__(self, trigger=(1, 'epoch'), image_generator=None):
     """Configure image reporting.

     Args:
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
         image_generator: Optional callable, stored as ``self._fn``.
     """
     _check_available()
     self._trigger = trigger_module.get_trigger(trigger)
     self._fn = image_generator
     self._info_name, self._infos = '.chainerui_images', []
예제 #34
0
 def __init__(self, experiment, model, function, trigger=(1, 'epoch')):
     """Bind an experiment, model and function to this extension.

     Args:
         experiment: Stored as ``self.experiment``.
         model: Stored as ``self.model``.
         function: Stored as ``self.function``.
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
     """
     self.experiment, self.model, self.function = experiment, model, function
     self._trigger = trigger_module.get_trigger(trigger)
예제 #35
0
파일: trainer.py 프로젝트: mattya/chainer
    def extend(self, extension, name=None, trigger=None, priority=None,
               invoke_before_training=None):
        """Register an extension on the trainer.

        An extension is a callable invoked after each update unless its
        trigger decides to skip that iteration. Extensions with higher
        priority run earlier within an iteration. Ties run in registration
        order.

        If a name is registered more than once, the later registrations get
        a ``_N`` suffix, where ``N`` is the ordinal of the duplicate.

        See :class:`Extension` for the extension interface.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. When ``None``, the extension's
                ``name`` attribute is used, then ``default_name``. The name
                may receive an ordinal suffix as described above.
            trigger (tuple or Trigger): Decides when the extension is
                invoked. When ``None``, ``extension.trigger`` is used (or
                ``None`` if absent), and the result is passed to
                ``trigger_module.get_trigger``.
            priority (int): Invocation priority; higher runs first. When
                ``None``, ``extension.priority`` is used, or
                ``extension_module.PRIORITY_READER`` if absent.
            invoke_before_training (bool or None): If true, the extension is
                also invoked just before the training loop starts. When
                ``None``, ``extension.invoke_before_training`` is used, or
                ``False`` if absent. This is mainly for extensions that
                change the training configuration (e.g. learning rates), so
                that resuming from a snapshot restores it before any update.

        Raises:
            TypeError: If no name can be determined.
            ValueError: If the resolved name is ``'training'``.
        """
        if name is None:
            for attr_name in ('name', 'default_name'):
                name = getattr(extension, attr_name, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', None)
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(
                extension, 'priority', extension_module.PRIORITY_READER)

        if invoke_before_training is None:
            invoke_before_training = getattr(
                extension, 'invoke_before_training', False)

        final_name, count = name, 0
        while final_name in self._extensions:
            count += 1
            final_name = '%s_%d' % (name, count)

        extension.name = final_name
        self._extensions[final_name] = _ExtensionEntry(
            extension, priority, trigger, invoke_before_training)
 def pre_run(self, data, max_epochs=1):
     """Record ``max_epochs`` and an epoch-based stop trigger, then run.

     ``Engine`` never stores ``max_epochs``, so this interception keeps it
     on ``self`` before delegating to ``Engine.run(self.engine, ...)``.
     The return value of ``Engine.run`` is not propagated.
     """
     self.max_epochs = max_epochs
     epoch_limit = (max_epochs, 'epoch')
     self.stop_trigger = trigger_module.get_trigger(epoch_limit)
     Engine.run(self.engine, data, max_epochs)
예제 #37
0
 def __init__(self, shift, attr='lr', trigger=(1, 'epoch')):
     """Store the configuration for shifting the attribute named ``attr``.

     Args:
         shift: Stored as ``self.shift``.
         attr: Attribute name, ``'lr'`` by default.
         trigger: Trigger spec converted with ``trigger_module.get_trigger``.
     """
     self.shift, self.attr = shift, attr
     self.trigger = trigger_module.get_trigger(trigger)
예제 #38
0
    def extend(self,
               extension,
               name=None,
               trigger=None,
               priority=None,
               invoke_before_training=None):
        """Registers an extension to the trainer.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each iteration.
        Extensions with the same priority are invoked in the order of
        registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second to last extensions. The suffix is
        ``_N`` where N is the ordinal of the extensions.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register. Its ``name`` attribute is
                overwritten with the final (possibly suffixed) name.
            name (str): Name of the extension. If it is omitted, the first
                non-``None`` value among the extension's ``name``,
                ``default_name`` and ``__name__`` attributes is used, looked
                up in that order. Note that the name would be suffixed by an
                ordinal in case of duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.
            invoke_before_training (bool or None): If ``True``, the extension
                is also invoked just before entering the training loop. If this
                is ``None``, ``extension.invoke_before_training`` is used
                instead (``False`` if the attribute is missing). This option is
                mainly used for extensions that alter the training
                configuration (e.g., learning rates); in such a case, resuming
                from snapshots requires the call of extension to recover the
                configuration before any updates.

        Raises:
            TypeError: If ``name`` is not given and none of the extension's
                ``name``, ``default_name`` or ``__name__`` attributes
                provides one.
            ValueError: If the (given or resolved) name is ``"training"``.

        """
        if name is None:
            # Attributes are read lazily and in order: a later attribute is
            # only looked up when every earlier one is missing or ``None``.
            for attr_name in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr_name, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority',
                               extension_module.PRIORITY_READER)

        if invoke_before_training is None:
            invoke_before_training = getattr(extension,
                                             'invoke_before_training', False)

        # Disambiguate duplicates as name, name_1, name_2, ...
        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger, invoke_before_training)
예제 #39
0
 def __init__(self, key, compare, trigger=(1, 'epoch')):
     """Set up the state used to track the best observed value.

     Args:
         key: Stored as ``self._key``; presumably the name of the reported
             value to monitor (TODO confirm against ``__call__``).
         compare: Stored as ``self._compare``; presumably a callable that
             decides whether a new value beats the current best (confirm).
         trigger: Trigger specification such as ``(1, 'epoch')``;
             normalized with ``trigger_module.get_trigger`` and stored as
             ``self._interval_trigger``.
     """
     # ``_best_value`` starts as ``None``: nothing has been recorded yet.
     self._key, self._best_value = key, None
     self._interval_trigger = trigger_module.get_trigger(trigger)
     # ``_init_summary`` is defined elsewhere; presumably it (re)creates the
     # accumulated summary state.
     self._init_summary()
     self._compare = compare
예제 #40
0
파일: trainer.py 프로젝트: tkng/chainer
    def extend(self, extension, name=None, trigger=None, priority=None,
               **kwargs):
        """Registers an extension to the trainer.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each iteration.
        Extensions with the same priority are invoked in the order of
        registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second to last extensions. The suffix is
        ``_N`` where N is the ordinal of the extensions.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register. Its ``name`` attribute is
                overwritten with the final (possibly suffixed) name.
            name (str): Name of the extension. If it is omitted, the first
                non-``None`` value among the extension's ``name``,
                ``default_name`` and ``__name__`` attributes is used, looked
                up in that order. Note that the name would be suffixed by an
                ordinal in case of duplicated names as explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If it is ``None`` and the extension does not
                have the trigger attribute, the extension is triggered at every
                iteration by default. If the trigger is not callable, it is
                passed to :class:`IntervalTrigger` to build an interval
                trigger.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.
            **kwargs: Accepted only so that the removed
                ``invoke_before_training`` argument can be reported with a
                helpful message; it and any other keyword argument are
                rejected by the ``argument`` helpers.

        Raises:
            TypeError: If ``name`` is not given and none of the extension's
                ``name``, ``default_name`` or ``__name__`` attributes
                provides one.
            ValueError: If the (given or resolved) name is ``"training"``.

        """
        argument.check_unexpected_kwargs(
            kwargs,
            invoke_before_training='invoke_before_training has been removed '
            'since Chainer v2.0.0. Use initializer= instead.')
        argument.assert_kwargs_empty(kwargs)

        if name is None:
            # Attributes are read lazily and in order: a later attribute is
            # only looked up when every earlier one is missing or ``None``.
            for attr_name in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr_name, None)
                if name is not None:
                    break
            else:
                raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(
                extension, 'priority', extension_module.PRIORITY_READER)

        # Disambiguate duplicates as name, name_1, name_2, ...
        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger)