コード例 #1
0
    def __init__(self,
                 check_trigger=(1, 'epoch'),
                 monitor='main/loss',
                 patience=None,
                 mode='auto',
                 verbose=False,
                 max_trigger=(100, 'epoch'),
                 **kwargs):
        """Set up the early-stopping state and pick the comparison operator.

        Args:
            check_trigger: Trigger specification passed to
                ``util.get_trigger``; the result is kept as
                ``_interval_trigger``.
            monitor: Key of the monitored value. When ``mode`` is neither
                ``'max'`` nor ``'min'``, a key containing ``'accuracy'``
                selects ``operator.gt``; any other key selects
                ``operator.lt``.
            patience: Stored as ``self.patience``. When ``None``, the value
                of the ``patients`` alias is used if given, otherwise ``3``.
            mode: ``'max'`` selects ``operator.gt`` and ``'min'`` selects
                ``operator.lt``; every other value infers the direction
                from ``monitor``.
            verbose: If truthy, print which comparison operator was chosen.
            max_trigger: Trigger specification passed to
                ``util.get_trigger``; the result is kept as
                ``_max_trigger``.
            **kwargs: ``patients`` (an alias of ``patience``) is extracted
                from here through ``argument.parse_kwargs``.

        Raises:
            TypeError: If both ``patience`` and ``patients`` are given
                with non-``None`` values.
        """
        # `patients` is accepted as an alias of `patience`.
        alias, = argument.parse_kwargs(kwargs, ('patients', None))
        if alias is not None and patience is not None:
            raise TypeError(
                'Both \'patience\' and \'patients\' arguments are '
                'specified. \'patients\' is an alias of the former. '
                'Specify only \'patience\'.')
        if patience is None:
            patience = 3 if alias is None else alias

        self.count = 0
        self.patience = patience
        self.monitor = monitor
        self.verbose = verbose
        self.already_warning = False
        self._max_trigger = util.get_trigger(max_trigger)
        self._interval_trigger = util.get_trigger(check_trigger)

        self._init_summary()

        # Decide whether a larger monitored value counts as an improvement.
        # The monitor key is only inspected when the mode is not explicit.
        if mode == 'max':
            greater_is_better = True
        elif mode == 'min':
            greater_is_better = False
        else:
            greater_is_better = 'accuracy' in monitor

        if greater_is_better:
            self._compare = operator.gt
            if verbose:
                print('early stopping: operator is greater')
            # Start from the worst possible value for a maximised metric.
            self.best = float('-inf')
        else:
            self._compare = operator.lt
            if verbose:
                print('early stopping: operator is less')
            # Start from the worst possible value for a minimised metric.
            self.best = float('inf')
コード例 #2
0
 def __init__(self, key, compare, stock_num=5, trigger=(1, "epoch")):
     """Store the configuration and start with an empty list of values.

     Args:
         key: Stored as ``_key``.
         compare: Stored as ``_compare``.
         stock_num: Stored as ``_stock_num``; defaults to ``5``.
             NOTE(review): presumably the number of best values to
             keep -- confirm against the code that reads it.
         trigger: Trigger specification passed to ``get_trigger``; the
             result is kept as ``_interval_trigger``.
     """
     self._key, self._better_values = key, []
     self._interval_trigger = get_trigger(trigger)
     self._init_summary()
     # Assigned after the summary is initialised, as in the original order.
     self._compare, self._stock_num = compare, stock_num
コード例 #3
0
    def __init__(
            self, numerator_key, denominator_key, result_key,
            trigger=(1, 'epoch')):
        """Remember the three keys and reset both counters to zero.

        Args:
            numerator_key: Stored as ``_numerator_key``.
            denominator_key: Stored as ``_denominator_key``.
            result_key: Stored as ``_result_key``.
            trigger: Trigger specification passed to ``util.get_trigger``;
                the result is kept as ``_trigger``.
        """
        self._trigger = util.get_trigger(trigger)

        # Both counters start at zero.
        self._numerator = self._denominator = 0

        self._numerator_key, self._denominator_key, self._result_key = (
            numerator_key, denominator_key, result_key)
コード例 #4
0
    def setUp(self):
        """Create the mock optimizer, the extension under test and a trainer.

        Reads ``self.gamma``, ``self.step_value``, ``self.init`` and
        ``self.expect``, which must already be set on the test case
        (where they come from is not visible in this block).
        """
        # Mock optimizer exposing one parameter group with attribute 'x'.
        optimizer = mock.MagicMock()
        optimizer.param_groups = [{'x': None}]
        self.optimizer = optimizer
        self.extension = extensions.MultistepShift(
            'x', self.gamma, self.step_value, self.init, self.optimizer)

        self.interval = 1

        # Repeat every expected value `interval` times, keeping the order.
        repeated = []
        for value in self.expect:
            for _ in range(self.interval):
                repeated.append(value)
        self.expect = repeated

        self.trigger = util.get_trigger((self.interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = trainer
        # The trainer's updater hands back the mock optimizer built above.
        trainer.updater.get_optimizer.return_value = self.optimizer
コード例 #5
0
    def setUp(self):
        """Create the mock optimizer, the extension under test and a trainer.

        Reads ``self.rate``, ``self.max_count``, ``self.init``,
        ``self.target`` and ``self.expect``, which must already be set on
        the test case (where they come from is not visible in this block).
        """
        # Mock optimizer exposing one parameter group with attribute 'x'.
        optimizer = mock.MagicMock()
        optimizer.param_groups = [{'x': None}]
        self.optimizer = optimizer
        self.extension = extensions.PolynomialShift(
            'x', self.rate, self.max_count, self.init, self.target,
            self.optimizer)

        self.interval = 4

        # Repeat every expected value `interval` times, keeping the order.
        repeated = []
        for value in self.expect:
            for _ in range(self.interval):
                repeated.append(value)
        self.expect = repeated

        self.trigger = util.get_trigger((self.interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = trainer
        # The trainer's updater hands back the mock optimizer built above.
        trainer.updater.get_optimizer.return_value = self.optimizer
コード例 #6
0
 def __init__(self, key, compare, trigger=(1, 'epoch')):
     """Store the configuration; no best value is recorded initially.

     Args:
         key: Stored as ``_key``.
         compare: Stored as ``_compare``.
         trigger: Trigger specification passed to ``util.get_trigger``;
             the result is kept as ``_interval_trigger``.
     """
     # ``None`` marks that no best value has been recorded yet.
     self._best_value, self._key = None, key
     self._interval_trigger = util.get_trigger(trigger)
     self._init_summary()
     # Assigned after the summary is initialised, as in the original order.
     self._compare = compare