def test_send_recv(self):
    # Point-to-point: send a tensor to every other rank, then receive
    # one back from each peer and check its contents.
    rank = dist.get_rank()
    tensor = _build_tensor(rank + 1)
    for dest in range(0, dist.get_num_processes()):
        if dest == rank:
            continue
        dist.send(tensor, dest)

    for src in range(0, dist.get_num_processes()):
        if src == rank:
            continue
        tensor = _build_tensor(src + 1, value=-1)
        expected_tensor = _build_tensor(src + 1)
        dist.recv(tensor, src)
        self.assertEqual(tensor, expected_tensor)

    self._barrier()
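# The point-to-point tests above and below rely on a module-level
# _build_tensor helper that is not shown in this section. A minimal sketch,
# assuming it builds a cubic FloatTensor of side `size` filled with `value`
# (or with the size itself when no value is given); the exact shape and
# dtype are assumptions:
def _build_tensor(size, value=None):
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)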
def test_send_recv_any_source(self):
    rank = dist.get_rank()
    tensor = _build_tensor(10, rank)
    for dest in range(0, dist.get_num_processes()):
        if dest == rank:
            continue
        dist.send(tensor, dest)

    recv_ranks = set()
    for src in range(0, dist.get_num_processes()):
        if src == rank:
            continue
        tensor = _build_tensor(10, value=-1)
        dist.recv(tensor)
        recv_ranks.add(tensor.resize_(1)[0])

    self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
    self._barrier()
def test_isend(self):
    # Non-blocking send: rank 0 posts isend() to every other rank and
    # waits on the returned requests; the peers do a blocking recv().
    rank = dist.get_rank()
    world_size = dist.get_num_processes()

    if rank == 0:
        requests = [
            dist.isend(_build_tensor(dest, 10), dest)
            for dest in range(1, world_size)
        ]
        for request in requests:
            request.wait()
            self.assertTrue(request.is_completed())
    else:
        tensor = _build_tensor(rank, -1)
        dist.recv(tensor, 0)
        self.assertEqual(tensor, _build_tensor(rank, 10))

    self._barrier()
def test_get_rank(self):
    # Every process writes its rank to a file named after its PID; the set
    # of ranks collected from those files must cover the whole world.
    test_dir = os.path.join(TEMP_DIR, 'test_dir')
    pid = str(os.getpid())
    num_processes = dist.get_num_processes()
    with open(os.path.join(test_dir, pid), 'w') as f:
        f.write(str(dist.get_rank()))

    self._barrier()

    all_ranks = set()
    for f_name in os.listdir(test_dir):
        with open(os.path.join(test_dir, f_name), 'r') as f:
            all_ranks.add(int(f.read()))
    self.assertEqual(len(all_ranks), num_processes)

    self._barrier()

    # Rank 0 cleans up the rank files once everyone has read them.
    if dist.get_rank() == 0:
        for f_name in os.listdir(test_dir):
            os.unlink(os.path.join(test_dir, f_name))

    self._barrier()
def test_irecv(self):
    # Non-blocking receive: rank 0 posts an irecv() buffer for every other
    # rank, then waits on each request before checking the payload.
    rank = dist.get_rank()
    world_size = dist.get_num_processes()

    if rank == 0:
        expected_tensors = [
            _build_tensor(src, -1) for src in range(1, world_size)
        ]
        requests = [
            dist.irecv(expected_tensors[src - 1], src)
            for src in range(1, world_size)
        ]
        for src in range(1, world_size):
            requests[src - 1].wait()
            self.assertTrue(requests[src - 1].is_completed())
            self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
    else:
        tensor = _build_tensor(rank, 10)
        dist.send(tensor, 0)

    self._barrier()
@classmethod
def sync(cls, timeout=5):
    # File-system barrier: each process bumps its own barrier file to the
    # current barrier id, then spins until every process has caught up.
    cls.barrier_id += 1
    barrier_dir = os.path.join(TEMP_DIR, 'barrier')
    pid = str(os.getpid())
    barrier_file = os.path.join(barrier_dir, pid)
    with _lock():
        with open(barrier_file, 'w') as f:
            f.write(str(cls.barrier_id))

    start_time = time.time()
    while True:
        arrived = 0
        with _lock():
            for f_name in os.listdir(barrier_dir):
                with open(os.path.join(barrier_dir, f_name), 'r') as f:
                    data = f.read()
                    if int(data) >= cls.barrier_id:
                        arrived += 1
        if arrived == dist.get_num_processes():
            break

        if time.time() - start_time > timeout:
            raise RuntimeError("barrier timeout")
        time.sleep(0.1)
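# sync() serializes access to the barrier files through a _lock() helper that
# is not shown here. A minimal sketch, assuming an fcntl advisory lock on a
# shared lockfile under TEMP_DIR (the lockfile path is an assumption):
import contextlib
import fcntl

@contextlib.contextmanager
def _lock():
    lockfile = os.path.join(TEMP_DIR, 'lockfile')
    with open(lockfile, 'w') as lf:
        try:
            fcntl.lockf(lf.fileno(), fcntl.LOCK_EX)
            yield
        finally:
            fcntl.lockf(lf.fileno(), fcntl.LOCK_UN)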
            # (tail of the all_reduce benchmark: stop the timer and report)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
dist.barrier()

# Benchmark scatter: rank 0 times sending a list of tensors, every other
# rank receives its slice.
if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_num_processes())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.scatter_send(tensors, tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter_recv(tensor, 0)
dist.barrier()
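# The benchmark loops above call timer, print_header, and print_stats, which
# are defined elsewhere in the script. A minimal sketch, assuming timer is
# timeit.default_timer and that throughput is reported in MB/s (the exact
# report format is an assumption):
from timeit import default_timer as timer

def print_header(name):
    print('---- {} ----'.format(name))
    print('{:>12} {:>12} {:>14}'.format('bytes', 'tensors', 'MB/s'))

def print_stats(bytes, num_tensors, elapsed):
    megabytes = bytes * num_tensors / (1024.0 * 1024.0)
    print('{:>12} {:>12} {:>14.2f}'.format(bytes, num_tensors, megabytes / elapsed))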
def _init_global_test(self):
    group = list(range(0, dist.get_num_processes()))
    group_id = dist.group.WORLD
    rank = dist.get_rank()
    return (group, group_id, rank)
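# Group-scoped test variants presumably use a companion helper built on
# dist.new_group(). A sketch, assuming a fixed two-rank subgroup (the choice
# of ranks and the fallback return value are assumptions):
def _init_group_test(self):
    group = [1, 2]
    group_id = dist.new_group(group)
    rank = dist.get_rank()
    if rank not in group:
        # Ranks outside the subgroup get no group handle.
        return ([], None, rank)
    return (group, group_id, rank)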