def test_get_rank_size_group(self):
    """Check rank / world-size queries against a two-member subgroup.

    Every rank calls ``dist.new_group`` (presumably required because group
    creation is collective). Members of the group must see a group size of 2
    and a group-local rank in {0, 1}. Non-members must get -1 from both
    ``get_world_size`` and ``get_rank`` for that group.
    """
    # Prefer ranks [1, 2] when there are enough processes; otherwise fall
    # back to [0, 1].
    if dist.get_world_size() > 2:
        group = [1, 2]
    else:
        group = [0, 1]
    group_id = dist.new_group(group)
    if dist.get_rank() in group:
        self.assertEqual(dist.get_world_size(group_id), 2)
        # assertIn reports the actual rank on failure, unlike assertTrue.
        self.assertIn(dist.get_rank(group_id), range(2))
    else:
        self.assertEqual(dist.get_world_size(group_id), -1)
        self.assertEqual(dist.get_rank(group_id), -1)
def test_send_recv_any_source(self):
    """Exchange tensors in turn, receiving from an unspecified source.

    Ranks take turns being the receiver (outer loop over ``dst``). The
    receiver posts ``world_size - 1`` wildcard ``dist.recv`` calls. Every
    other rank sends it ``_build_tensor(10, value=rank)``. The receiver
    checks that each payload equals the sender rank returned by
    ``dist.recv``, and that it heard from ``world_size - 1`` distinct ranks.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = _build_tensor(10, value=rank)
    recv_ranks = set()
    for dst in range(world_size):
        if dst == rank:
            # Recv mode: one wildcard receive per peer. The counter is
            # deliberately not named ``dst`` so it does not shadow the
            # outer loop variable.
            for _ in range(world_size - 1):
                output_tensor = _build_tensor(10, value=-1)
                sender = dist.recv(output_tensor)
                # Assert the scalar value "sender" that should be
                # equal to the rank of the sender is equal to all
                # values in the received tensor.
                self.assertTrue(output_tensor.eq(sender).all())
                recv_ranks.add(sender)
        else:
            # Send mode
            dist.send(tensor, dst)
    self.assertEqual(len(recv_ranks), world_size - 1)
    self._barrier()
def _init_group_test(self):
    """Create the subgroup of ranks 1 and 2 and describe this rank's view of it.

    ``dist.new_group`` is called before the membership check, so every rank
    takes part in creating the group.

    Returns:
        ``(group, group_id, rank)`` for members of the subgroup, and
        ``([], None, rank)`` for every other rank.
    """
    members = [1, 2]
    handle = dist.new_group(members)
    my_rank = dist.get_rank()
    if my_rank in members:
        return (members, handle, my_rank)
    return ([], None, my_rank)
def test_get_rank(self):
    """Check that the processes report pairwise-distinct ranks.

    Each process writes its rank into a file named after its PID inside
    ``TEMP_DIR/test_dir``. After a barrier, every process reads all files
    back and expects exactly ``world_size`` distinct ranks. Rank 0 then
    deletes the files, with barriers keeping the phases separated.
    """
    rank_dir = os.path.join(TEMP_DIR, "test_dir")
    expected_count = dist.get_world_size()
    marker_path = os.path.join(rank_dir, str(os.getpid()))
    with open(marker_path, "w") as marker:
        marker.write(str(dist.get_rank()))
    self._barrier()

    def _read_rank(name):
        # One file per process; its whole content is the writer's rank.
        with open(os.path.join(rank_dir, name), "r") as handle:
            return int(handle.read())

    seen = {_read_rank(name) for name in os.listdir(rank_dir)}
    self.assertEqual(len(seen), expected_count)
    self._barrier()

    # Only one process cleans up, and only after everyone has finished reading.
    if dist.get_rank() == 0:
        for name in os.listdir(rank_dir):
            os.unlink(os.path.join(rank_dir, name))
    self._barrier()
def test_send_recv(self):
    """Blocking point-to-point send/recv between every ordered pair of ranks.

    Ranks take turns as the sender (outer loop over ``src``). The sender
    sends ``_build_tensor(rank + 1)`` to every other rank. Each other rank
    receives from ``src`` into a ``value=-1`` buffer and compares the result
    with ``_build_tensor(src + 1)``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = _build_tensor(rank + 1)
    # Take turns rather than "everyone sends, then everyone receives". With
    # blocking dist.send, the latter only works if the backend buffers sends,
    # and it can deadlock if a send waits for its matching recv.
    for src in range(world_size):
        if src == rank:
            # Send mode
            for dst in range(world_size):
                if dst == rank:
                    continue
                dist.send(tensor, dst)
        else:
            # Recv mode
            expected_tensor = _build_tensor(src + 1)
            output_tensor = _build_tensor(src + 1, value=-1)
            dist.recv(output_tensor, src)
            self.assertEqual(output_tensor, expected_tensor)
    self._barrier()
def test_isend(self):
    """Rank 0 posts non-blocking sends to every other rank; the peers receive.

    Rank 0 sends ``_build_tensor(dest, 10)`` to each ``dest`` with
    ``dist.isend``. It then waits on every request and checks that each one
    reports completion. Every other rank does a blocking ``dist.recv`` from
    rank 0 and compares the payload with ``_build_tensor(rank, 10)``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    if rank != 0:
        received = _build_tensor(rank, -1)
        dist.recv(received, 0)
        self.assertEqual(received, _build_tensor(rank, 10))
    else:
        # Post every send first, then wait on all of them.
        pending = []
        for peer in range(1, world_size):
            pending.append(dist.isend(_build_tensor(peer, 10), peer))
        for handle in pending:
            handle.wait()
            self.assertTrue(handle.is_completed())
    self._barrier()
def test_send_recv_any_source(self):
    """Exchange tensors in turn, receiving from an unspecified source.

    Ranks take turns being the receiver (outer loop over ``dst``). The
    receiver posts ``world_size - 1`` wildcard ``dist.recv`` calls, and every
    other rank sends it ``_build_tensor(10, rank)``. Each payload must equal
    the sender rank returned by ``dist.recv``, and ``world_size - 1``
    distinct senders must be seen.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = _build_tensor(10, rank)
    recv_ranks = set()
    # Take turns rather than "everyone sends, then everyone receives". With
    # blocking dist.send, the latter can deadlock unless the backend buffers
    # sends.
    for dst in range(world_size):
        if dst == rank:
            # Recv mode: one wildcard receive per peer, each into a fresh
            # buffer, so the tensor being sent is never overwritten.
            for _ in range(world_size - 1):
                output_tensor = _build_tensor(10, value=-1)
                sender = dist.recv(output_tensor)
                self.assertTrue(output_tensor.eq(sender).all())
                recv_ranks.add(sender)
        else:
            # Send mode
            dist.send(tensor, dst)
    self.assertEqual(len(recv_ranks), world_size - 1)
    self._barrier()
def test_send_recv(self):
    """Blocking point-to-point send/recv, one sending rank at a time.

    For each ``sender`` in rank order: that rank sends
    ``_build_tensor(rank + 1)`` to every peer. Every other rank receives from
    ``sender`` into a ``value=-1`` buffer and checks it against
    ``_build_tensor(sender + 1)``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    payload = _build_tensor(rank + 1)
    for sender in range(world_size):
        if sender != rank:
            # Recv mode: pull the current sender's tensor and verify it.
            received = _build_tensor(sender + 1, value=-1)
            dist.recv(received, sender)
            self.assertEqual(received, _build_tensor(sender + 1))
            continue
        # Send mode: this rank's turn to push its tensor to every peer.
        for peer in (r for r in range(world_size) if r != rank):
            dist.send(payload, peer)
    self._barrier()
def test_irecv(self):
    """Rank 0 posts non-blocking receives from every other rank; the peers send.

    Rank 0 allocates one ``value=-1`` buffer per peer and posts
    ``dist.irecv`` into each. It then waits on every request in peer order,
    checking completion and that the buffer equals
    ``_build_tensor(peer, 10)``. Every other rank does a blocking send of
    ``_build_tensor(rank, 10)`` to rank 0.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    if rank != 0:
        dist.send(_build_tensor(rank, 10), 0)
    else:
        peers = range(1, world_size)
        # Receive buffers, pre-filled with -1 so a missed write is visible.
        buffers = [_build_tensor(peer, -1) for peer in peers]
        handles = [dist.irecv(buf, peer) for buf, peer in zip(buffers, peers)]
        for peer, buf, handle in zip(peers, buffers, handles):
            handle.wait()
            self.assertTrue(handle.is_completed())
            self.assertEqual(buf, _build_tensor(peer, 10))
    self._barrier()
def _init_global_test(self):
    """Describe the default WORLD group from this rank's point of view.

    Returns:
        ``(group, group_id, rank)``: ``group`` lists every rank
        ``0 .. world_size - 1``, ``group_id`` is ``dist.group.WORLD``, and
        ``rank`` is this process's global rank.
    """
    group = list(range(dist.get_world_size()))
    group_id = dist.group.WORLD
    rank = dist.get_rank()
    return (group, group_id, rank)
def test_get_rank_size_full_group(self):
    """Check that a group of all ranks mirrors the global world size and rank.

    Queries on the group from ``_init_full_group_test`` are compared with the
    same queries made without a group.
    """
    full_group = self._init_full_group_test()[1]
    # Same query, with and without the group; first world size, then rank.
    for query in (dist.get_world_size, dist.get_rank):
        self.assertEqual(query(full_group), query())
def _init_full_group_test(self):
    """Create a fresh group with ``dist.new_group()`` and describe it.

    ``dist.new_group`` is called without a rank list. Per the torch docs,
    that means all ranks.

    Returns:
        ``(group, group_id, rank)``: ``group`` lists every rank
        ``0 .. world_size - 1``, ``group_id`` is the newly created group, and
        ``rank`` is this process's global rank.
    """
    group = list(range(dist.get_world_size()))
    group_id = dist.new_group()
    rank = dist.get_rank()
    return (group, group_id, rank)