def test_partial(self):
     """Starting at the last iteration index yields just that one step."""
     # Index appears to be laid out as
     # (epoch * num_paths + path) * num_chunks + chunk; this is (1, 2, 3).
     resume_idx = (1 * 3 + 2) * 4 + 3
     manager = IterationManager(
         num_epochs=2,
         edge_paths=["A", "B", "C"],
         num_edge_chunks=4,
         iteration_idx=resume_idx,
     )
     remaining = list(manager)
     self.assertEqual(remaining, [(1, 2, 3)])
 def test_properties(self):
     """Check the index properties and the checkpoint metadata dict."""
     paths = ["A", "B", "C"]
     # Positioned at epoch 0, edge path 1 ("B"), edge chunk 2.
     manager = IterationManager(
         num_epochs=2,
         edge_paths=paths,
         num_edge_chunks=4,
         iteration_idx=(0 * 3 + 1) * 4 + 2,
     )

     # Individual properties first, in the same order as the index layout.
     self.assertEqual(manager.epoch_idx, 0)
     self.assertEqual(manager.edge_path_idx, 1)
     self.assertEqual(manager.edge_path, "B")
     self.assertEqual(manager.edge_chunk_idx, 2)

     # The metadata is expected to mirror those values plus the totals.
     metadata = manager.get_checkpoint_metadata()
     expected_metadata = {
         "iteration/num_epochs": 2,
         "iteration/epoch_idx": 0,
         "iteration/num_edge_paths": 3,
         "iteration/edge_path_idx": 1,
         "iteration/edge_path": "B",
         "iteration/num_edge_chunks": 4,
         "iteration/edge_chunk_idx": 2,
     }
     self.assertEqual(metadata, expected_metadata)
 def test_tampering(self):
     """Changes made to the manager mid-iteration affect the live iterator."""
     manager = IterationManager(
         num_epochs=2,
         edge_paths=["A", "B", "C"],
         num_edge_chunks=4,
     )
     steps = iter(manager)
     self.assertEqual(next(steps), (0, 0, 0))

     # Point the manager at (0, 1, 1); the index is incremented when next()
     # is called, so the step that comes out is (0, 1, 2).
     manager.iteration_idx = (0 * 3 + 1) * 4 + 1
     self.assertEqual(next(steps), (0, 1, 2))

     # Shrink the layout to 2 paths x 2 chunks; the following step is
     # expected to be decoded against the new sizes.
     manager.edge_paths = ["foo", "bar"]
     manager.num_edge_chunks = 2
     self.assertEqual(next(steps), (1, 1, 1))

     # An index far past the end must terminate the iteration.
     manager.iteration_idx = 100
     self.assertRaises(StopIteration, next, steps)
 def test_full(self):
     """A manager with no starting index yields every triple in order."""
     manager = IterationManager(
         num_epochs=2,
         edge_paths=["A", "B", "C"],
         num_edge_chunks=4,
     )
     visited = list(manager)
     # Every (epoch, edge path, edge chunk) combination, rightmost fastest.
     expected = list(product(range(2), range(3), range(4)))
     self.assertEqual(visited, expected)