def test_skips_full_epoch(self):
    """Skipping as many steps as the first epoch holds drops that epoch.

    Inputs: 4 total steps, 2 steps to skip, epoch sizes [2, 2].
    Expected: a single remaining epoch of 2 steps.
    """
    remaining = trainer_lib.epochs(
        total_steps=4,
        steps_to_skip=2,
        epoch_steps=[2, 2],
    )
    # Materialize the result so it can be compared against a plain list.
    actual = list(remaining)
    self.assertEqual(actual, [2])
def test_skips_part_of_epoch(self):
    """Skipping fewer steps than the first epoch holds shortens that epoch.

    Inputs: 4 total steps, 1 step to skip, epoch sizes [2, 2].
    Expected: the first epoch shrinks to 1 step, the second stays at 2.
    """
    remaining = trainer_lib.epochs(
        total_steps=4,
        steps_to_skip=1,
        epoch_steps=[2, 2],
    )
    # Materialize the result so it can be compared against a plain list.
    actual = list(remaining)
    self.assertEqual(actual, [1, 2])
def test_cuts_epoch_when_total_steps_reached(self):
    """An epoch that would overshoot ``total_steps`` is cut short.

    Inputs: 5 total steps, nothing skipped, epoch sizes [1, 2, 3].
    Expected: [1, 2, 2] -- the final epoch is reduced from 3 to 2 so the
    sizes add up to the 5 total steps.
    """
    remaining = trainer_lib.epochs(
        total_steps=5,
        steps_to_skip=0,
        epoch_steps=[1, 2, 3],
    )
    # Materialize the result so it can be compared against a plain list.
    actual = list(remaining)
    self.assertEqual(actual, [1, 2, 2])