def test_interval_is_1(self, mock_save_check):
        """With ``period=1``, every ``on_checkpoint`` call should trigger a save.

        ``mock_save_check`` is presumably the patched save routine injected by a
        decorator outside this view; only its ``call_count`` is inspected.
        """
        state = {}
        check = Interval('test_file', period=1)

        check.on_checkpoint(state)
        check.on_checkpoint(state)

        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(mock_save_check.call_count, 2)

    def test_interval_on_batch(self, mock_save_check):
        """With ``on_batch=True``, saves follow training steps, not ``on_checkpoint``.

        A save is expected on every fourth ``on_step_training`` call. After the
        loop, a call to ``on_checkpoint`` must not add another save.
        ``mock_save_check`` is presumably the patched save routine injected by a
        decorator outside this view; only its ``call_count`` is inspected.
        """
        state = {}
        check = Interval('test_file', period=4, on_batch=True)

        # Expected cumulative save count after a given (1-based) number of steps.
        # Step 7 pins the count between the 1st and 2nd save: no early save.
        expected_saves = {4: 1, 7: 1, 8: 2}
        for step in range(1, 14):
            check.on_step_training(state)
            if step in expected_saves:
                self.assertEqual(mock_save_check.call_count, expected_saves[step])

        # 13 steps -> saves after steps 4, 8 and 12. Checked before
        # on_checkpoint so the two behaviours are verified independently.
        self.assertEqual(mock_save_check.call_count, 3)

        check.on_checkpoint(state)
        # In on_batch mode, on_checkpoint must not trigger an extra save.
        self.assertEqual(mock_save_check.call_count, 3)

    def test_interval_is_more_than_1(self, mock_save_check):
        """With ``period=4``, a save is expected on every fourth ``on_checkpoint`` call.

        ``mock_save_check`` is presumably the patched save routine injected by a
        decorator outside this view; only its ``call_count`` is inspected.
        """
        state = {}
        check = Interval('test_file', period=4)

        # Expected cumulative save count after a given (1-based) number of calls.
        # Call 7 pins the count between the 1st and 2nd save: no early save.
        expected_saves = {4: 1, 7: 1, 8: 2}
        for call in range(1, 14):
            check.on_checkpoint(state)
            if call in expected_saves:
                self.assertEqual(mock_save_check.call_count, expected_saves[call])

        # 13 calls -> saves after calls 4, 8 and 12.
        self.assertEqual(mock_save_check.call_count, 3)