Example #1
    def test_send_recv(self):
        rank = dist.get_rank()
        tensor = _build_tensor(rank + 1)
        for dest in range(0, dist.get_num_processes()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        for src in range(0, dist.get_num_processes()):
            if src == rank:
                continue
            tensor = _build_tensor(src + 1, value=-1)
            expected_tensor = _build_tensor(src + 1)
            dist.recv(tensor, src)
            self.assertEqual(tensor, expected_tensor)

        self._barrier()
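
All of these tests lean on a _build_tensor helper that the excerpts do not show. A minimal sketch consistent with how it is called here (a size plus an optional fill value that defaults to the size itself) could look like this; the cubic shape is an assumption:

import torch

def _build_tensor(size, value=None):
    # Hypothetical helper: the default fill value is the size itself, which
    # is why test_send_recv can reconstruct the expected tensor from src + 1.
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)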
Example #2
    def test_send_recv_any_source(self):
        rank = dist.get_rank()
        tensor = _build_tensor(10, rank)
        for dest in range(0, dist.get_num_processes()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        recv_ranks = set()
        for src in range(0, dist.get_num_processes()):
            if src == rank:
                continue
            tensor = _build_tensor(10, value=-1)
            dist.recv(tensor)
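            # the payload was filled with the sender's rank, so any element identifies the source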
            recv_ranks.add(tensor.resize_(1)[0])

        self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
        self._barrier()
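
Because recv is called without a source, arrival order is arbitrary; the test only verifies that every peer was heard from by counting distinct rank values. In today's torch.distributed the same check can rely on recv's return value, which is the sender's rank. A minimal sketch under that assumption (modern API names, process group already initialized):

import torch
import torch.distributed as dist

def recv_from_everyone():
    # recv() without src blocks until a message from any peer arrives
    # and returns that peer's rank, so no payload trick is needed.
    world_size = dist.get_world_size()
    seen = set()
    for _ in range(world_size - 1):
        buf = torch.empty(10)
        seen.add(dist.recv(buf))
    assert len(seen) == world_size - 1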
Example #3
    def test_isend(self):
        rank = dist.get_rank()
        world_size = dist.get_num_processes()

        if rank == 0:
            requests = [
                dist.isend(_build_tensor(dest, 10), dest)
                for dest in range(1, world_size)
            ]
            for request in requests:
                request.wait()
                self.assertTrue(request.is_completed())
        else:
            tensor = _build_tensor(rank, -1)
            dist.recv(tensor, 0)
            self.assertEqual(tensor, _build_tensor(rank, 10))

        self._barrier()
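
These methods come from a test class whose setup is not shown; to run anything like them standalone you need a process group. A minimal launcher sketch, assuming the modern torch.distributed API (init_process_group with the gloo backend, get_world_size in place of the legacy get_num_processes):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        reqs = [dist.isend(torch.full((d,), 10.0), dst=d)
                for d in range(1, world_size)]
        for req in reqs:
            req.wait()  # isend only guarantees the buffer is safe to reuse after wait()
    else:
        buf = torch.full((rank,), -1.0)
        dist.recv(buf, src=0)
        assert torch.equal(buf, torch.full((rank,), 10.0))
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)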
Example #4
    def test_get_rank(self):
        test_dir = os.path.join(TEMP_DIR, 'test_dir')
        pid = str(os.getpid())
        num_processes = dist.get_num_processes()
        with open(os.path.join(test_dir, pid), 'w') as f:
            f.write(str(dist.get_rank()))

        self._barrier()

        all_ranks = set()
        for f_name in os.listdir(test_dir):
            with open(os.path.join(test_dir, f_name), 'r') as f:
                all_ranks.add(int(f.read()))
        self.assertEqual(len(all_ranks), num_processes)

        self._barrier()

        if dist.get_rank() == 0:
            for f_name in os.listdir(test_dir):
                os.unlink(os.path.join(test_dir, f_name))

        self._barrier()
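
TEMP_DIR and the test_dir under it come from harness setup code that the excerpt omits. Something along these lines is assumed to run first; the exact mechanism is a guess, but every rank must resolve the same path, so the real harness would pass it in, for example through an environment variable:

import os

TEMP_DIR = os.environ.get('TEMP_DIR', '/tmp/dist_test')  # hypothetical shared scratch dir
os.makedirs(os.path.join(TEMP_DIR, 'test_dir'), exist_ok=True)
os.makedirs(os.path.join(TEMP_DIR, 'barrier'), exist_ok=True)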
Example #5
    def test_irecv(self):
        rank = dist.get_rank()
        world_size = dist.get_num_processes()

        if rank == 0:
            expected_tensors = [
                _build_tensor(src, -1) for src in range(1, world_size)
            ]
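            # each buffer starts at -1 and should hold 10s once its irecv completes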
            requests = [
                dist.irecv(expected_tensors[src - 1], src)
                for src in range(1, world_size)
            ]

            for src in range(1, world_size):
                requests[src - 1].wait()
                self.assertTrue(requests[src - 1].is_completed())
                self.assertEqual(expected_tensors[src - 1],
                                 _build_tensor(src, 10))
        else:
            tensor = _build_tensor(rank, 10)
            dist.send(tensor, 0)

        self._barrier()
Example #6
    @classmethod
    def sync(cls, timeout=5):
        cls.barrier_id += 1
        barrier_dir = os.path.join(TEMP_DIR, 'barrier')
        pid = str(os.getpid())
        barrier_file = os.path.join(barrier_dir, pid)
        with _lock():
            with open(barrier_file, 'w') as f:
                f.write(str(cls.barrier_id))

        start_time = time.time()
        while True:
            arrived = 0
            with _lock():
                for f_name in os.listdir(barrier_dir):
                    with open(os.path.join(barrier_dir, f_name), 'r') as f:
                        data = f.read()
                        if int(data) >= cls.barrier_id:
                            arrived += 1
            if arrived == dist.get_num_processes():
                break

            if time.time() - start_time > timeout:
                raise RuntimeError("barrier timeout")
            time.sleep(0.1)
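
The _lock() context manager that guards the barrier files is also not shown. A plausible sketch using an fcntl advisory lock on a shared lockfile (POSIX-only; the lockfile path is an assumption, reusing TEMP_DIR from above):

import fcntl
import os
from contextlib import contextmanager

@contextmanager
def _lock():
    # Exclusive lock so concurrent readers and writers of the barrier
    # directory never observe a half-written file.
    with open(os.path.join(TEMP_DIR, 'lockfile'), 'w') as lf:
        try:
            fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
            yield
        finally:
            fcntl.flock(lf.fileno(), fcntl.LOCK_UN)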
Example #7
# head of the all_reduce section (mirrors the scatter section below)
if rank == 0:
    print_header("all_reduce")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
dist.barrier()

if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_num_processes())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.scatter_send(tensors, tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter_recv(tensor, 0)
dist.barrier()
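
The benchmark references a few helpers and constants defined elsewhere in the script. Plausible definitions are sketched below; the exact bounds and output format are assumptions:

from timeit import default_timer as timer

MIN_BYTES, MAX_BYTES = 10, 20            # assumed range: 1 KiB .. 512 KiB payloads
MIN_NUM_TENSORS, MAX_NUM_TENSORS = 2, 5  # assumed range: 100 .. 10,000 iterations

def print_header(title):
    print('---- {} ----'.format(title))
    print('{:>10} {:>10} {:>12}'.format('bytes', 'iters', 'seconds'))

def print_stats(nbytes, num_tensors, elapsed):
    print('{:>10} {:>10} {:>12.4f}'.format(nbytes, num_tensors, elapsed))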
Example #8
    def _init_global_test(self):
        group = [i for i in range(0, dist.get_num_processes())]
        group_id = dist.group.WORLD
        rank = dist.get_rank()
        return (group, group_id, rank)
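
A collective test would typically unpack this tuple before exercising an op. A sketch of the pattern, with test_broadcast as a hypothetical caller against the same legacy API the excerpts use:

    def test_broadcast(self):
        group, group_id, rank = self._init_global_test()
        for src in group:
            # every rank agrees on the shape; only src holds the real values
            tensor = _build_tensor(src + 1, value=src if rank == src else -1)
            dist.broadcast(tensor, src, group_id)
            self.assertEqual(tensor, _build_tensor(src + 1, value=src))
        self._barrier()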