def test_logging(self):
    """Check that communication rounds/bytes are logged only when verbose.

    With verbosity enabled, this exercises send/recv and each collective op
    and asserts the recorded round and byte counts. It then checks that
    ``crypten.reset_communication_stats`` zeroes the counters. Finally, it
    checks that nothing is recorded once verbosity is disabled again.
    """

    def assert_stats(rounds, nbytes):
        # Compare the communicator's counters against the expected values.
        self.assertEqual(comm.get().comm_rounds, rounds)
        self.assertEqual(comm.get().comm_bytes, nbytes)

    # Assert initialization resets comm.get() stats
    assert_stats(0, 0)

    sizes = [(), (1,), (5,), (5, 5), (5, 5, 5)]
    # The expected byte counts assume 8 bytes per element of the integer
    # test tensors (presumably int64) -- TODO confirm against the encoder.
    bytes_per_element = 8

    # Test verbosity True setting and logging. Verbosity is switched off in
    # the ``finally`` clause so a failing assertion does not leave logging
    # enabled for tests that run afterwards.
    comm.get().set_verbosity(True)
    try:
        # Test send / recv:
        for size in sizes:
            tensor = get_random_test_tensor(size=size, is_float=False)
            crypten.reset_communication_stats()

            # Send forward, receive backward around the ring of parties:
            # rank 0 sends before it receives, every other rank receives
            # first and then forwards.
            dst = (self.rank + 1) % self.world_size
            src = (self.rank - 1) % self.world_size
            if self.rank == 0:
                comm.get().send(tensor, dst=dst)
            tensor = comm.get().recv(tensor, src=src)
            if self.rank > 0:
                comm.get().send(tensor, dst=dst)

            # One send and one recv per party: two rounds, each moving the
            # whole tensor once.
            assert_stats(2, tensor.numel() * bytes_per_element * 2)

        # Test all other ops. Every op except all_reduce / all_gather takes
        # a root rank (0 here): dst for gather / reduce, src for
        # broadcast / scatter.
        ops = [
            "all_reduce",
            "all_gather",
            "broadcast",
            "gather",
            "reduce",
            "scatter",
        ]
        for size in sizes:
            for op in ops:
                tensor = get_random_test_tensor(size=size, is_float=False)
                nbytes = tensor.numel() * bytes_per_element
                crypten.reset_communication_stats()

                # Setup op-specific inputs
                args = () if op in ("all_reduce", "all_gather") else (0,)
                if op == "scatter":
                    # scatter takes one tensor per party.
                    tensor = [tensor] * self.world_size

                getattr(comm.get(), op)(tensor, *args)
                assert_stats(1, nbytes * (self.world_size - 1))

        # Test reset_communication_stats
        crypten.reset_communication_stats()
        assert_stats(0, 0)
    finally:
        comm.get().set_verbosity(False)

    # Test verbosity False setting and no logging (verbosity was turned off
    # by the ``finally`` clause above). Use the largest size explicitly
    # rather than relying on the leaked loop variable.
    tensor = get_random_test_tensor(size=sizes[-1], is_float=False)
    comm.get().broadcast(tensor, src=0)
    assert_stats(0, 0)
def test_logging(self):
    """Check that communication stats are logged only when verbose.

    With ``cfg.communicator.verbose`` enabled, this exercises send/recv and
    each collective op and asserts the recorded round and byte counts. It
    then checks that ``crypten.reset_communication_stats`` zeroes the
    counters and that ``get_communication_stats`` reports zeroed "rounds",
    "bytes" and "time" entries. Finally, it checks that nothing is recorded
    once verbosity is disabled again.
    """

    def assert_stats(rounds, nbytes):
        # Compare the communicator's counters against the expected values.
        self.assertEqual(comm.get().comm_rounds, rounds)
        self.assertEqual(comm.get().comm_bytes, nbytes)

    # Assert initialization resets comm.get() stats
    assert_stats(0, 0)

    sizes = [(), (1,), (5,), (5, 5), (5, 5, 5)]
    # The expected byte counts assume 8 bytes per element of the integer
    # test tensors (presumably int64) -- TODO confirm against the encoder.
    bytes_per_element = 8

    # Test verbosity True setting and logging. Verbosity is switched off in
    # the ``finally`` clause so a failing assertion does not leave logging
    # enabled in the global config for tests that run afterwards.
    cfg.communicator.verbose = True
    try:
        # Test send / recv:
        for size in sizes:
            tensor = get_random_test_tensor(size=size, is_float=False)
            crypten.reset_communication_stats()

            # Send forward, receive backward around the ring of parties:
            # rank 0 sends before it receives, every other rank receives
            # first and then forwards.
            dst = (self.rank + 1) % self.world_size
            src = (self.rank - 1) % self.world_size
            if self.rank == 0:
                comm.get().send(tensor, dst=dst)
            tensor = comm.get().recv(tensor, src=src)
            if self.rank > 0:
                comm.get().send(tensor, dst=dst)

            # One send and one recv per party: two rounds, each moving the
            # whole tensor once.
            assert_stats(2, tensor.numel() * bytes_per_element * 2)

        # Test all other ops. Every op except all_reduce / all_gather takes
        # a root rank (0 here): dst for gather / reduce, src for
        # broadcast / scatter.
        unrooted_ops = ("all_reduce", "all_gather")
        ops = [
            "all_reduce",
            "all_gather",
            "broadcast",
            "gather",
            "reduce",
            "scatter",
        ]
        for size in sizes:
            for op in ops:
                tensor = get_random_test_tensor(size=size, is_float=False)
                nbytes = tensor.numel() * bytes_per_element
                crypten.reset_communication_stats()

                # Setup op-specific inputs
                args = () if op in unrooted_ops else (0,)
                if op == "scatter":
                    # scatter takes one tensor per party.
                    tensor = [tensor] * self.world_size

                getattr(comm.get(), op)(tensor, *args)

                # all_reduce / all_gather are expected to record twice the
                # bytes of the rooted collectives.
                factor = 2 if op in unrooted_ops else 1
                assert_stats(1, factor * nbytes * (self.world_size - 1))

        # Test reset_communication_stats
        crypten.reset_communication_stats()
        assert_stats(0, 0)

        # test retrieving communication stats:
        stats = comm.get().get_communication_stats()
        self.assertIsInstance(stats, dict)
        for key in ["rounds", "bytes", "time"]:
            self.assertIn(key, stats)
            self.assertEqual(stats[key], 0)
    finally:
        cfg.communicator.verbose = False

    # Test verbosity False setting and no logging (verbosity was turned off
    # by the ``finally`` clause above). Use the largest size explicitly
    # rather than relying on the leaked loop variable.
    tensor = get_random_test_tensor(size=sizes[-1], is_float=False)
    comm.get().broadcast(tensor, 0)
    assert_stats(0, 0)