def test_interval_is_1(self, mock_save_check):
    """Check that ``period=1`` records one save per ``on_checkpoint`` call.

    Args:
        mock_save_check: Mock injected by the test harness. Its patch target
            is not visible in this block; presumably it stands in for the
            checkpoint save routine -- confirm against the decorator.
    """
    state = {}
    check = Interval('test_file', period=1)

    check.on_checkpoint(state)
    check.on_checkpoint(state)

    # assertEqual (rather than assertTrue(a == b)) reports the actual count
    # when the expectation is not met.
    self.assertEqual(mock_save_check.call_count, 2)
def test_interval_on_batch(self, mock_save_check):
    """Check save counts when ``Interval`` is driven by training steps.

    Builds ``Interval('test_file', period=4, on_batch=True)``, calls
    ``on_step_training`` 13 times and checks the mock's call count after the
    4th, 7th and 8th calls, then once more after a single ``on_checkpoint``.

    Args:
        mock_save_check: Mock injected by the test harness. Its patch target
            is not visible in this block; presumably it stands in for the
            checkpoint save routine -- confirm against the decorator.
    """
    state = {}
    check = Interval('test_file', period=4, on_batch=True)

    # Zero-based step index -> number of saves expected right after that step.
    # The count at index 6 was previously only checked to be "not 2"; since a
    # mock's call_count never decreases and must be 1 at index 3 and 2 at
    # index 7, the only passing value there is 1, so it is pinned exactly.
    expected_saves = {3: 1, 6: 1, 7: 2}

    for step in range(13):
        check.on_step_training(state)
        if step in expected_saves:
            self.assertEqual(mock_save_check.call_count, expected_saves[step])

    # NOTE(review): on_checkpoint is invoked once after the step loop;
    # presumably it does not save in on_batch mode, so the total stays at
    # three -- confirm against Interval.
    check.on_checkpoint(state)
    self.assertEqual(mock_save_check.call_count, 3)
def test_interval_is_more_than_1(self, mock_save_check):
    """Check save counts when ``on_checkpoint`` is called with ``period=4``.

    Calls ``on_checkpoint`` 13 times on ``Interval('test_file', period=4)``
    and checks the mock's call count after the 4th, 7th, 8th and final call.

    Args:
        mock_save_check: Mock injected by the test harness. Its patch target
            is not visible in this block; presumably it stands in for the
            checkpoint save routine -- confirm against the decorator.
    """
    state = {}
    check = Interval('test_file', period=4)

    # Zero-based call index -> number of saves expected right after that call.
    # The count at index 6 was previously only checked to be "not 2"; since a
    # mock's call_count never decreases and must be 1 at index 3 and 2 at
    # index 7, the only passing value there is 1, so it is pinned exactly.
    expected_saves = {3: 1, 6: 1, 7: 2}

    for call_index in range(13):
        check.on_checkpoint(state)
        if call_index in expected_saves:
            self.assertEqual(
                mock_save_check.call_count, expected_saves[call_index]
            )

    # Total after all 13 calls.
    self.assertEqual(mock_save_check.call_count, 3)