Пример #1
0
    def test_auto_shoud_be_max(self, _):
        """With no explicit ``mode``, monitoring ``'acc_loss'`` should resolve to ``'max'``.

        The mode is checked after ``on_end_epoch`` has run once.
        NOTE(review): ``_`` presumably receives a mock injected by a ``@patch``
        decorator that is not visible in this snippet — TODO confirm.
        """
        state = {torchbearer.METRICS: {'acc_loss': 0.1}}

        file_path = 'test_file_{acc_loss:.2f}'
        check = Best(file_path, monitor='acc_loss')
        check.on_start(state)

        check.on_end_epoch(state)
        # assertEqual reports both values on failure; assertTrue(a == b) only says "False is not true".
        self.assertEqual(check.mode, 'max')
Пример #2
0
    def test_state_dict(self):
        """Round-trip ``Best`` bookkeeping fields through ``state_dict``/``load_state_dict``.

        Three attributes are set on one instance, its state is exported, and a
        fresh instance loaded from that state must expose the same values.
        """
        original = Best('test')
        original.most_recent = 'temp'
        original.best = 'temp2'
        original.epochs_since_last_save = 10

        saved = original.state_dict()

        restored = Best('test')
        restored.load_state_dict(saved)

        # Checked in the same order as the fields were assigned above.
        expectations = (
            ('most_recent', 'temp'),
            ('best', 'temp2'),
            ('epochs_since_last_save', 10),
        )
        for attr_name, expected in expectations:
            self.assertEqual(getattr(restored, attr_name), expected)
    def test_min_delta_save(self, mock_save):
        """In ``'min'`` mode with ``min_delta=0.1``, a big enough drop triggers another save.

        The monitored value goes from 0.1 to -0.001, an improvement of 0.101,
        which is just above the 0.1 delta. The test therefore expects a second
        save after the second ``on_checkpoint`` call.
        NOTE(review): ``mock_save`` is presumably a patched ``torch.save``
        injected by a decorator not visible here — TODO confirm.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, mode='min', min_delta=0.1)
        check.on_start(state)

        # First checkpoint always saves: there is no previous best yet.
        check.on_checkpoint(state)
        self.assertEqual(mock_save.call_count, 1)

        state = {torchbearer.METRICS: {'val_loss': -0.001}}
        check.on_checkpoint(state)
        self.assertEqual(mock_save.call_count, 2)
    def test_max_with_decreasing(self, mock_save):
        """In ``'max'`` mode, a lower monitored value must not trigger another save.

        NOTE(review): ``monitor`` is not passed, so this assumes ``Best`` monitors
        ``'val_loss'`` by default; ``mock_save`` is presumably a patched
        ``torch.save`` from a decorator not visible here — TODO confirm both.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, mode='max')
        check.on_start(state)

        check.on_checkpoint(state)
        self.assertEqual(mock_save.call_count, 1)

        # 0.001 < 0.1 is not an improvement in 'max' mode, so the count stays at 1.
        state = {torchbearer.METRICS: {'val_loss': 0.001}}
        check.on_checkpoint(state)
        self.assertEqual(mock_save.call_count, 1)
Пример #5
0
    def test_min_with_increasing(self, mock_save):
        """In ``'min'`` mode, a higher monitored value must not trigger another save.

        NOTE(review): ``monitor`` is not passed, so this assumes ``Best`` monitors
        ``'val_loss'`` by default; ``mock_save`` is presumably a patched
        ``torch.save`` from a decorator not visible here — TODO confirm both.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, mode='min')
        check.on_start(state)

        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 1)

        # 0.2 > 0.1 is a regression in 'min' mode, so no second save is expected.
        state = {torchbearer.METRICS: {'val_loss': 0.2}}
        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 1)
    def test_auto_shoud_be_max(self, _):
        """With no explicit ``mode``, monitoring ``'acc_loss'`` should resolve to ``'max'``.

        The mode is checked after ``on_checkpoint`` has run once.
        NOTE(review): ``_`` presumably receives a mock injected by a ``@patch``
        decorator that is not visible in this snippet — TODO confirm.
        """
        state = {torchbearer.METRICS: {'acc_loss': 0.1}}

        file_path = 'test_file_{acc_loss:.2f}'
        check = Best(file_path, monitor='acc_loss')
        check.on_start(state)

        check.on_checkpoint(state)
        # assertEqual reports both values on failure; assertTrue(a == b) only says "False is not true".
        self.assertEqual(check.mode, 'max')
Пример #7
0
    def test_auto_shoud_be_min(self, _):
        """With no explicit ``mode``, monitoring ``'val_loss'`` should resolve to ``'min'``.

        The mode is checked after ``on_end_epoch`` has run once.
        NOTE(review): ``_`` presumably receives a mock injected by a ``@patch``
        decorator that is not visible in this snippet — TODO confirm.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, monitor='val_loss')
        check.on_start(state)

        check.on_end_epoch(state)
        # assertEqual reports both values on failure; assertTrue(a == b) only says "False is not true".
        self.assertEqual(check.mode, 'min')
Пример #8
0
    def test_min_delta_save(self, mock_save):
        """In ``'min'`` mode with ``min_delta=0.1``, a big enough drop triggers another save.

        The monitored value goes from 0.1 to -0.001, an improvement of 0.101,
        which is just above the 0.1 delta. The test therefore expects a second
        save after the second ``on_end_epoch`` call.
        NOTE(review): ``mock_save`` is presumably a patched ``torch.save``
        injected by a decorator not visible here — TODO confirm.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, mode='min', min_delta=0.1)
        check.on_start(state)

        # First epoch always saves: there is no previous best yet.
        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 1)

        state = {torchbearer.METRICS: {'val_loss': -0.001}}
        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 2)
Пример #9
0
    def test_max_with_decreasing(self, mock_save):
        """In ``'max'`` mode, a lower monitored value must not trigger another save.

        NOTE(review): ``monitor`` is not passed, so this assumes ``Best`` monitors
        ``'val_loss'`` by default; ``mock_save`` is presumably a patched
        ``torch.save`` from a decorator not visible here — TODO confirm both.
        """
        state = {torchbearer.METRICS: {'val_loss': 0.1}}

        file_path = 'test_file_{val_loss:.2f}'
        check = Best(file_path, mode='max')
        check.on_start(state)

        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 1)

        # 0.001 < 0.1 is not an improvement in 'max' mode, so the count stays at 1.
        state = {torchbearer.METRICS: {'val_loss': 0.001}}
        check.on_end_epoch(state)
        self.assertEqual(mock_save.call_count, 1)
Пример #10
0
    def test_bad_monitor(self, _):
        """Monitoring a metric absent from the state's metrics should emit exactly one warning.

        ``Best`` is told to monitor ``'test_fail'``, but the state only carries
        ``'acc_loss'``.
        NOTE(review): ``_`` presumably receives a mock injected by a ``@patch``
        decorator that is not visible in this snippet — TODO confirm.
        """
        state = {torchbearer.METRICS: {'acc_loss': 0.1}}

        file_path = 'test_file_{acc_loss:.2f}'
        check = Best(file_path, monitor='test_fail')
        check.on_start(state)

        with warnings.catch_warnings(record=True) as w:
            # record=True alone does not reset the filters. If the same warning
            # was already issued from this location earlier in the run, the
            # default filter would suppress it and leave ``w`` empty. Force every
            # warning to be recorded so the count is deterministic.
            warnings.simplefilter('always')
            check.on_checkpoint(state)
            self.assertEqual(len(w), 1)