Exemplo n.º 1
0
    def _test_broadcast_helper(self,
                               group,
                               group_id,
                               rank,
                               cuda=False,
                               rank_to_GPU=None):
        """Check ``dist.broadcast`` for several tensor types from every source.

        For each (tensor type, fill value) case, every rank in ``group`` takes
        a turn as the broadcast source.  The source rank broadcasts a tensor
        built with ``_build_tensor(src + 1, value)``.  Every other rank
        broadcasts into a tensor built with ``_build_tensor(src + 1, -1)``
        and asserts that its size matches the expected tensor and that no
        element differs from it.  A final ``self._barrier()`` runs after all
        cases.

        ``_build_tensor`` is defined elsewhere; presumably it returns a tensor
        sized from its first argument and filled with its second — TODO
        confirm.

        Args:
            group: iterable of ranks; each acts as the source once per case.
            group_id: group handle forwarded to ``dist.broadcast``.
            rank: rank of the calling process.
            cuda: when true, tensors are moved with ``.cuda(...)``; when
                false, cases flagged as CUDA-only (``torch.HalfTensor``) are
                skipped.
            rank_to_GPU: indexed as ``rank_to_GPU[rank][0]`` to choose the
                device; only read when ``cuda`` is true.
        """
        # (tensor type name, fill value, case only runs when cuda is true)
        cases = (
            ("torch.FloatTensor", -1e-10, False),
            ("torch.DoubleTensor", -1e-100, False),
            ("torch.HalfTensor", -0.1, True),
            ("torch.CharTensor", -2, False),
            ("torch.ByteTensor", 129, False),
            ("torch.IntTensor", -1e5, False),
            ("torch.LongTensor", -1e15, False),
        )
        for type_name, fill_value, cuda_only in cases:
            if cuda_only and not cuda:
                continue
            for src in group:
                expected = _build_tensor(src + 1, fill_value).type(type_name)
                if cuda:
                    expected = expected.cuda(rank_to_GPU[rank][0])

                if rank != src:
                    # Receiver starts from a different fill value (-1),
                    # presumably so a broadcast that delivers nothing fails
                    # the comparison below.
                    received = _build_tensor(src + 1, -1).type(type_name)
                    if cuda:
                        received = received.cuda(rank_to_GPU[rank][0])
                    dist.broadcast(received, src, group_id)
                    self.assertEqual(received.size(), expected.size())
                    self.assertEqual(received.ne(expected).max(), 0)
                else:
                    dist.broadcast(expected, src, group_id)

        self._barrier()
Exemplo n.º 2
0
    def _test_barrier_helper(self, group, group_id, rank):
        """Check that ``dist.barrier`` does not release ranks early.

        Each rank in ``group`` takes a turn as ``dest``.  The ``dest`` rank
        stores a deadline (its ``time.time()`` plus the wait time) in a
        one-element ``DoubleTensor`` and broadcasts it.  It then sleeps
        slightly longer than the wait time before entering the barrier.
        Every other rank receives the deadline, enters the barrier, and
        asserts that the current time is at or past the deadline.  A final
        ``self._barrier()`` runs after all rounds.

        NOTE(review): comparing timestamps taken in different processes
        assumes the ranks read comparable wall clocks (e.g. they run on one
        host) — confirm.
        """
        wait_seconds = 0.3

        for dest in group:
            deadline = torch.DoubleTensor(1).fill_(0.0)
            is_dest = dest == rank

            if is_dest:
                deadline.fill_(time.time() + wait_seconds)
            dist.broadcast(deadline, dest, group_id)

            if is_dest:
                # Sleep a little past the deadline before joining the barrier.
                time.sleep(wait_seconds + 0.1)
            dist.barrier(group_id)

            if not is_dest:
                self.assertGreaterEqual(time.time(), deadline[0])

        self._barrier()
Exemplo n.º 3
0
 def sync_parameters(self):
     """Broadcast every parameter of ``self.module`` from rank 0.

     Calls ``dist.broadcast(tensor, 0)`` on each parameter's ``.data``, in
     the order yielded by ``self.module.parameters()``.  Per
     ``torch.distributed.broadcast``, the tensor is sent from rank 0.  On
     every other rank it is overwritten in place with the received data.
     """
     parameter_tensors = (param.data for param in self.module.parameters())
     for tensor in parameter_tensors:
         dist.broadcast(tensor, 0)