def test_auto_should_be_max(self):
        """With no explicit mode, an 'acc'-named monitor should resolve to 'max' mode."""
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {'acc_metric': 0.001}
        }

        # No `mode` argument: the stopper has to pick the direction itself.
        stopper = EarlyStopping(monitor='acc_metric')

        stopper.on_start(state)
        stopper.on_end_epoch(state)

        # assertEqual reports the actual mode on failure, whereas
        # assertTrue(a == b) only reports "False is not true".
        self.assertEqual(stopper.mode, 'max')
    def test_state_dict(self):
        """state_dict()/load_state_dict() should round-trip the `wait` and `best` values."""
        source = EarlyStopping(monitor='test_metric_1')
        source.wait, source.best = 10, 20
        saved = source.state_dict()

        # A fresh stopper must not already hold the saved value;
        # otherwise the load below would prove nothing.
        target = EarlyStopping(monitor='test_metric_1')
        self.assertNotEqual(target.wait, 10)

        target.load_state_dict(saved)
        self.assertEqual(target.wait, 10)
        self.assertEqual(target.best, 20)
    def test_step_on_batch(self):
        """With step_on_batch=True, step() should run per training step, not at epoch end."""
        stopper = EarlyStopping(monitor='test_metric',
                                mode='min',
                                step_on_batch=True)

        # step() is replaced by a mock, so only its call count is observed.
        # That is why a dummy string can stand in for the state argument.
        stopper.step = MagicMock()

        stopper.on_step_training('test')
        self.assertEqual(stopper.step.call_count, 1)

        # Ending the epoch must not trigger an extra step in batch mode.
        stopper.on_end_epoch('test')
        self.assertEqual(stopper.step.call_count, 1)
# Example #4
# 0
    def test_max_equal_should_stop(self):
        """In 'max' mode, repeating the same metric value should stop training."""
        monitored = {'test_metric': 0.001}
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: monitored,
        }
        stopper = EarlyStopping(monitor='test_metric', mode='max')
        stopper.on_start(state)

        # First epoch: training is expected to continue.
        stopper.on_end_epoch(state)
        self.assertFalse(state[torchbearer.STOP_TRAINING])

        # Same value again (no increase): a stop is expected.
        stopper.on_end_epoch(state)
        self.assertTrue(state[torchbearer.STOP_TRAINING])
# Example #5
# 0
    def test_patience_should_stop(self):
        """With patience=3 and a constant metric, three epochs continue and the fourth stops."""
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {
                'test_metric': 0.001
            }
        }

        stopper = EarlyStopping(monitor='test_metric', patience=3)

        stopper.on_start(state)

        # The loop index is unused; `_` makes that explicit.
        for _ in range(3):
            stopper.on_end_epoch(state)
            self.assertFalse(state[torchbearer.STOP_TRAINING])

        # One more non-improving epoch exhausts the patience.
        stopper.on_end_epoch(state)
        self.assertTrue(state[torchbearer.STOP_TRAINING])
    def test_max_equal_should_stop(self):
        """In 'max' mode, an unchanged metric on the second epoch should stop training."""
        # NOTE(review): a method with this name also appears earlier in this file;
        # if both live in the same class, only this later definition is kept.
        run_state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {'test_metric': 0.001},
        }
        early_stop = EarlyStopping(monitor='test_metric', mode='max')
        early_stop.on_start(run_state)

        early_stop.on_end_epoch(run_state)
        self.assertFalse(run_state[torchbearer.STOP_TRAINING])

        early_stop.on_end_epoch(run_state)
        self.assertTrue(run_state[torchbearer.STOP_TRAINING])
    def test_patience_should_stop(self):
        """With patience=3 and a constant metric, three epochs continue and the fourth stops."""
        # NOTE(review): a method with this name also appears earlier in this file;
        # if both live in the same class, only this later definition is kept.
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {'test_metric': 0.001}
        }

        stopper = EarlyStopping(monitor='test_metric', patience=3)

        stopper.on_start(state)

        # The loop index is unused; `_` makes that explicit.
        for _ in range(3):
            stopper.on_end_epoch(state)
            self.assertFalse(state[torchbearer.STOP_TRAINING])

        # One more non-improving epoch exhausts the patience.
        stopper.on_end_epoch(state)
        self.assertTrue(state[torchbearer.STOP_TRAINING])
# Example #8
# 0
    def test_min_should_continue(self):
        """In 'min' mode, a metric that goes down should not stop training."""
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {'test_metric': 0.001},
        }
        stopper = EarlyStopping(monitor='test_metric', mode='min')
        stopper.on_start(state)

        stopper.on_end_epoch(state)
        self.assertFalse(state[torchbearer.STOP_TRAINING])

        # A lower value is expected to count as an improvement in 'min' mode.
        state[torchbearer.METRICS]['test_metric'] = 0.0001
        stopper.on_end_epoch(state)
        self.assertFalse(state[torchbearer.STOP_TRAINING])
    def test_min_delta_should_continue(self):
        """With min_delta=0.1 in 'max' mode, a rise of 0.101 should keep training going."""
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {'test_metric': 0.001},
        }
        stopper = EarlyStopping(monitor='test_metric', mode='max', min_delta=0.1)
        stopper.on_start(state)

        stopper.on_end_epoch(state)
        self.assertFalse(state[torchbearer.STOP_TRAINING])

        # 0.001 -> 0.102 is an increase of 0.101, just above min_delta.
        state[torchbearer.METRICS]['test_metric'] = 0.102
        stopper.on_end_epoch(state)
        self.assertFalse(state[torchbearer.STOP_TRAINING])
# Example #10
# 0
    def test_auto_should_be_max(self):
        """With no explicit mode, an 'acc'-named monitor should resolve to 'max' mode."""
        # NOTE(review): a method with this name also appears earlier in this file;
        # if both live in the same class, only this later definition is kept.
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.METRICS: {
                'acc_metric': 0.001
            }
        }

        # No `mode` argument: the stopper has to pick the direction itself.
        stopper = EarlyStopping(monitor='acc_metric')

        stopper.on_start(state)
        stopper.on_end_epoch(state)

        # assertEqual reports the actual mode on failure, whereas
        # assertTrue(a == b) only reports "False is not true".
        self.assertEqual(stopper.mode, 'max')