示例#1
0
 def test_skips_full_epoch(self):
     """Skipping 2 of 4 steps with epochs [2, 2] leaves only the second epoch."""
     remaining = list(
         trax.epochs(epoch_steps=[2, 2], steps_to_skip=2, total_steps=4))
     self.assertEqual(remaining, [2])
示例#2
0
 def test_skips_part_of_epoch(self):
     """Skipping 1 step shortens the first epoch of [2, 2] to a single step."""
     remaining = list(
         trax.epochs(epoch_steps=[2, 2], steps_to_skip=1, total_steps=4))
     self.assertEqual(remaining, [1, 2])
示例#3
0
 def test_cuts_epoch_when_total_steps_reached(self):
     """The final epoch is truncated so the yielded lengths sum to total_steps.

     With epochs [1, 2, 3] and total_steps=5, the last epoch is cut from
     3 steps down to 2.
     """
     produced = list(
         trax.epochs(epoch_steps=[1, 2, 3], steps_to_skip=0, total_steps=5))
     self.assertEqual(produced, [1, 2, 2])