def _test_scatter_helper(self, group, group_id, rank):
    # Each rank in the group takes a turn as the scatter root. The root
    # scatters one tensor per rank (filled with that rank's index); every
    # rank then checks that it received a tensor filled with its own rank.
    for dest in group:
        tensor = _build_tensor(dest + 1, -1)
        expected_tensor = _build_tensor(dest + 1, rank)
        if rank == dest:
            tensors = [_build_tensor(dest + 1, i) for i in group]
            dist.scatter_send(tensors, tensor, group_id)
            self.assertEqual(tensor, expected_tensor)
        else:
            dist.scatter_recv(tensor, dest, group_id)
            self.assertEqual(tensor, expected_tensor)

    self._barrier()
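
# The helper above relies on _build_tensor, which is not shown in this
# snippet. A minimal sketch of what it is assumed to look like: a cubic
# FloatTensor of the given size filled with a constant value, so that a
# correct scatter is easy to verify element-wise.
def _build_tensor(size, value):
    return torch.FloatTensor(size, size, size).fill_(value)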
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
    tensor = torch.ByteTensor(bytes).fill_(42)
    for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
        for i in range(0, num_tensors):
            dist.all_reduce(tensor)

dist.barrier()

if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_num_processes())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.scatter_send(tensors, tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter_recv(tensor, 0)

dist.barrier()

if rank == 0:
    print_header("gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
    tensor = torch.ByteTensor(bytes).fill_(42)
    for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
        for i in range(0, num_tensors):
            dist.all_reduce(tensor)

dist.barrier()

if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.scatter_send(tensors, tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter_recv(tensor, 0)

dist.barrier()

if rank == 0:
    print_header("gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
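
# The benchmark loops above assume some scaffolding that is not part of this
# snippet. A minimal sketch under assumed names and ranges (MIN_BYTES,
# MAX_BYTES, MIN_NUM_TENSORS, MAX_NUM_TENSORS, timer, print_header,
# print_stats); the real script may define these differently.
from timeit import default_timer as timer

MIN_BYTES = 10        # smallest message: 2**10 bytes
MAX_BYTES = 20        # largest message: 2**19 bytes
MIN_NUM_TENSORS = 1   # fewest iterations per measurement: 10**1
MAX_NUM_TENSORS = 4   # most iterations per measurement: 10**3


def print_header(name):
    print("---- {} ----".format(name))
    print("{:>12} {:>12} {:>14}".format("bytes", "iterations", "MB/s"))


def print_stats(bytes, num_tensors, elapsed):
    # Approximate throughput: total payload moved divided by wall time.
    megabytes = bytes * num_tensors / (1024.0 * 1024.0)
    print("{:>12} {:>12} {:>14.2f}".format(bytes, num_tensors, megabytes / elapsed))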