def test_state_dict(self):
    """Round-trip ``best`` and ``epochs_since_last_save`` through a state dict.

    Values set on one ``Best`` instance are exported with ``state_dict()``,
    loaded into a freshly constructed instance, and must read back unchanged.
    """
    source = Best('test')
    source.best = 'temp2'
    source.epochs_since_last_save = 10
    exported = source.state_dict()

    restored = Best('test')
    restored.load_state_dict(exported)

    self.assertEqual(restored.best, 'temp2')
    self.assertEqual(restored.epochs_since_last_save, 10)
def test_auto_shoud_be_max(self, _):
    """Check that the mode resolves to ``'max'`` when monitoring ``acc_loss``.

    No ``mode`` is passed to ``Best``, so whatever default it applies must
    end up as ``'max'`` for this monitor name.

    Args:
        _: Unused extra argument; presumably a mock injected by a patch
            decorator outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'acc_loss': 0.1}}
    file_path = 'test_file_{acc_loss:.2f}'

    check = Best(file_path, monitor='acc_loss')
    check.on_start(state)
    check.on_checkpoint(state)

    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(check.mode, 'max')
def test_auto_shoud_be_min(self, _):
    """Check that the mode resolves to ``'min'`` when monitoring ``val_loss``.

    No ``mode`` is passed to ``Best``, so whatever default it applies must
    end up as ``'min'`` for this monitor name.

    Args:
        _: Unused extra argument; presumably a mock injected by a patch
            decorator outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'val_loss': 0.1}}
    file_path = 'test_file_{val_loss:.2f}'

    check = Best(file_path, monitor='val_loss')
    check.on_start(state)
    # NOTE(review): several sibling tests drive the callback through
    # on_checkpoint; this one uses on_end_epoch - confirm that is intended.
    check.on_end_epoch(state)

    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(check.mode, 'min')
def test_bad_monitor(self, _):
    """Check that exactly one warning is raised for an unknown monitor key.

    ``Best`` is told to monitor ``'test_fail'``, which is absent from the
    metrics in ``state``; running the checkpoint hook must then emit exactly
    one warning.

    Args:
        _: Unused extra argument; presumably a mock injected by a patch
            decorator outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'acc_loss': 0.1}}
    file_path = 'test_file_{acc_loss:.2f}'

    check = Best(file_path, monitor='test_fail')
    check.on_start(state)

    with warnings.catch_warnings(record=True) as caught:
        # Reset the filters inside the context so the warning is always
        # recorded, regardless of -W options or filters installed by the
        # test runner or by earlier tests.
        warnings.simplefilter('always')
        check.on_checkpoint(state)
        # assertEqual reports the actual count on failure.
        self.assertEqual(len(caught), 1)
def test_min_delta_save(self, mock_save):
    """Check that a sufficiently large improvement triggers a second save.

    With ``mode='min'`` and ``min_delta=0.1``, the first checkpoint at
    ``val_loss=0.1`` must save once. The metric then drops to ``-0.001``
    (a change of 0.101, just over ``min_delta``) and a second save is
    expected.

    Args:
        mock_save: Mock whose ``call_count`` tracks saves; presumably
            patched over the checkpoint save routine by a decorator
            outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'val_loss': 0.1}}
    file_path = 'test_file_{val_loss:.2f}'

    check = Best(file_path, mode='min', min_delta=0.1)
    check.on_start(state)

    check.on_checkpoint(state)
    # assertEqual reports the actual call count on failure.
    self.assertEqual(mock_save.call_count, 1)

    state = {torchbearer.METRICS: {'val_loss': -0.001}}
    check.on_checkpoint(state)
    self.assertEqual(mock_save.call_count, 2)
def test_max_with_decreasing(self, mock_save):
    """Check that ``mode='max'`` does not save again when the metric drops.

    The first checkpoint at ``val_loss=0.1`` must save once; after the
    metric decreases to ``0.001`` the save count must stay at one.

    Args:
        mock_save: Mock whose ``call_count`` tracks saves; presumably
            patched over the checkpoint save routine by a decorator
            outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'val_loss': 0.1}}
    file_path = 'test_file_{val_loss:.2f}'

    check = Best(file_path, mode='max')
    check.on_start(state)

    check.on_checkpoint(state)
    # assertEqual reports the actual call count on failure.
    self.assertEqual(mock_save.call_count, 1)

    state = {torchbearer.METRICS: {'val_loss': 0.001}}
    check.on_checkpoint(state)
    self.assertEqual(mock_save.call_count, 1)
def test_min_with_increasing(self, mock_save):
    """Check that ``mode='min'`` does not save again when the metric rises.

    The first epoch end at ``val_loss=0.1`` must save once; after the
    metric increases to ``0.2`` the save count must stay at one.

    Args:
        mock_save: Mock whose ``call_count`` tracks saves; presumably
            patched over the checkpoint save routine by a decorator
            outside this view (TODO confirm).
    """
    state = {torchbearer.METRICS: {'val_loss': 0.1}}
    file_path = 'test_file_{val_loss:.2f}'

    check = Best(file_path, mode='min')
    check.on_start(state)

    # NOTE(review): several sibling tests drive the callback through
    # on_checkpoint; this one uses on_end_epoch - confirm that is intended.
    check.on_end_epoch(state)
    # assertEqual reports the actual call count on failure.
    self.assertEqual(mock_save.call_count, 1)

    state = {torchbearer.METRICS: {'val_loss': 0.2}}
    check.on_end_epoch(state)
    self.assertEqual(mock_save.call_count, 1)