Beispiel #1
0
    def __init__(self,
                 check_trigger=(1, 'epoch'),
                 monitor='main/loss',
                 patience=None,
                 mode='auto',
                 verbose=False,
                 max_trigger=(100, 'epoch'),
                 **kwargs):
        """Initialize the early-stopping state.

        Args:
            check_trigger: Trigger spec passed to ``trigger_util.get_trigger``
                and stored as ``self._interval_trigger``.
            monitor: Name of the observed value to compare. In ``'auto'``
                mode, a name containing ``'accuracy'`` selects "greater is
                better".
            patience: Stored as ``self.patience``; defaults to 3 when neither
                it nor its alias ``patients`` is given. Presumably the number
                of non-improving checks tolerated before stopping (the check
                logic is not visible here).
            mode: ``'max'`` compares with ``operator.gt``, ``'min'`` with
                ``operator.lt``. Any other value is treated as auto-detect
                based on ``monitor``.
            verbose: If true, print which comparison operator was chosen.
            max_trigger: Trigger spec passed to ``trigger_util.get_trigger``
                and stored as ``self._max_trigger``.
            **kwargs: Only ``patients`` (an alias of ``patience``) is
                accepted.

        Raises:
            TypeError: If both ``patience`` and ``patients`` are given, or if
                any other keyword argument is passed.
        """
        # `patients` is accepted as an alias of `patience`. Any other keyword
        # is a caller mistake (e.g. a misspelling) and must not be silently
        # ignored, otherwise the intended setting would quietly fall back to
        # its default.
        patients = kwargs.pop('patients', None)
        if kwargs:
            raise TypeError(
                'unexpected keyword argument(s): {}'.format(
                    ', '.join(repr(k) for k in sorted(kwargs))))

        if patients is not None:
            if patience is not None:
                raise TypeError(
                    'Both \'patience\' and \'patients\' arguments are '
                    'specified. \'patients\' is an alias of the former. '
                    'Specify only \'patience\'.')
            patience = patients
        elif patience is None:
            patience = 3

        self.count = 0
        self.patience = patience
        self.monitor = monitor
        self.verbose = verbose
        self.already_warning = False
        self._max_trigger = trigger_util.get_trigger(max_trigger)
        self._interval_trigger = trigger_util.get_trigger(check_trigger)

        self._init_summary()

        if mode == 'max':
            self._compare = operator.gt
        elif mode == 'min':
            self._compare = operator.lt
        elif 'accuracy' in monitor:
            # Auto mode: accuracy-like metrics improve upward.
            self._compare = operator.gt
        else:
            # Auto mode: everything else (e.g. losses) improves downward.
            self._compare = operator.lt

        # Seed `best` with the worst possible value for the chosen direction
        # so the first observed value always counts as an improvement.
        if self._compare is operator.gt:
            if verbose:
                print('early stopping: operator is greater')
            self.best = float('-inf')
        else:
            if verbose:
                print('early stopping: operator is less')
            self.best = float('inf')
Beispiel #2
0
def test_get_trigger(iters_per_epoch, trigger_args, expected):
    """Check that ``get_trigger(trigger_args)`` fires as ``expected``.

    ``expected`` holds one boolean per iteration. A leading ``False`` is
    prepended for the very first iteration.
    """
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = trigger_util.get_trigger(trigger_args)

    # before the first iteration, trigger should be False
    for e in [False] + expected:
        with trainer.run_iteration():
            assert trigger(trainer) == e
Beispiel #3
0
    def __init__(self,
                 numerator_key,
                 denominator_key,
                 result_key,
                 trigger=(1, 'epoch')):
        """Configure the keys and trigger for a numerator/denominator ratio.

        Args:
            numerator_key: Key naming the numerator entry.
            denominator_key: Key naming the denominator entry.
            result_key: Key under which the result is presumably reported
                (the reporting code is not visible here).
            trigger: Trigger spec passed to ``trigger_util.get_trigger``.
        """
        self._trigger = trigger_util.get_trigger(trigger)

        # Names of the entries this object reads and writes.
        self._numerator_key, self._denominator_key, self._result_key = (
            numerator_key, denominator_key, result_key)

        # Running totals both start at zero.
        self._numerator = self._denominator = 0
Beispiel #4
0
from pytorch_pfn_extras.training import trigger_util
from pytorch_pfn_extras.training import triggers


@pytest.mark.parametrize(
    'iters_per_epoch,trigger_args,expected',
    [
        # Never fire trigger
        (2, None, [False, False, False, False, False, False, False]),

        # Interval trigger
        (2, (2, 'iteration'), [False, True, False, True, False, True, False]),
        (2, (2, 'epoch'), [False, False, False, True, False, False, False]),

        # Callable object
        (2, trigger_util.get_trigger(None),
         [False, False, False, False, False, False, False]),
        (2, triggers.IntervalTrigger(
            2, 'iteration'), [False, True, False, True, False, True, False]),
        (2, (lambda trainer: trainer.iteration == 3),
         [False, False, True, False, False, False, False]),
    ])
def test_get_trigger(iters_per_epoch, trigger_args, expected):
    """Check that ``get_trigger(trigger_args)`` fires as ``expected``.

    ``expected`` holds one boolean per iteration. A leading ``False`` is
    prepended for the very first iteration.
    """
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = trigger_util.get_trigger(trigger_args)

    # before the first iteration, trigger should be False
    for e in [False] + expected:
        with trainer.run_iteration():
            assert trigger(trainer) == e
Beispiel #5
0
 def __init__(self, key, compare, trigger=(1, 'epoch')):
     """Prepare to track the best observed value of ``key``.

     Args:
         key: Name of the observed value to track.
         compare: Callable stored as ``self._compare``; presumably takes
             two values and reports whether the first is better (usage is
             not visible here).
         trigger: Trigger spec passed to ``trigger_util.get_trigger``.
     """
     # No value has been seen yet.
     self._key, self._best_value = key, None
     self._interval_trigger = trigger_util.get_trigger(trigger)
     self._init_summary()
     self._compare = compare