Example #1
    def __init__(self,
                 data_source,
                 batch_size=1,
                 num_replicas=None,
                 rank=None):
        """
        Samples batches assuming they are in order of size to batch similarly sized samples together.
        """
        super(DistributedBucketingSampler, self).__init__(data_source)
        if num_replicas is None:
            num_replicas = get_world_size()
        if rank is None:
            rank = get_rank()
        self.data_source = data_source
        self.ids = list(range(0, len(data_source)))
        self.batch_size = batch_size
        # group consecutive (size-sorted) indices into fixed-size batches
        self.bins = [
            self.ids[i:i + batch_size]
            for i in range(0, len(self.ids), batch_size)
        ]
        self.num_replicas = num_replicas
        self.rank = rank
        # number of batches each replica is responsible for
        self.num_samples = int(
            math.ceil(len(self.bins) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
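
Only the constructor is shown above. For context, here is a minimal sketch of the __iter__/__len__ pair such a bucketing sampler would typically add, assuming each replica takes a strided slice of the bins (these method bodies are an assumption, not part of the original example):

    def __iter__(self):
        # each replica picks every num_replicas-th bin, starting at its rank,
        # so the batches are partitioned disjointly across processes
        bins = self.bins[self.rank::self.num_replicas]
        return iter(bins)

    def __len__(self):
        return self.num_samples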
Example #2
    def _init_group_test(self):
        group = [1, 2]
        group_id = dist.new_group(group)
        rank = dist.get_rank()
        if rank not in group:
            return ([], None, rank)

        return (group, group_id, rank)
Example #3
def get_dist_info():
    if dist._initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
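
A common way to use get_dist_info is to gate rank-0-only work such as logging or checkpointing; a small hypothetical usage sketch:

rank, world_size = get_dist_info()
if rank == 0:
    # only the first process prints progress / writes checkpoints
    print('training with {} process(es)'.format(world_size))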
Example #4
    def test_get_rank(self):
        test_dir = os.path.join(TEMP_DIR, "test_dir")
        pid = str(os.getpid())
        num_processes = dist.get_world_size()
        with open(os.path.join(test_dir, pid), "w") as f:
            f.write(str(dist.get_rank()))

        self._barrier()

        all_ranks = set()
        for f_name in os.listdir(test_dir):
            with open(os.path.join(test_dir, f_name), "r") as f:
                all_ranks.add(int(f.read()))
        self.assertEqual(len(all_ranks), num_processes)

        self._barrier()

        if dist.get_rank() == 0:
            for f_name in os.listdir(test_dir):
                os.unlink(os.path.join(test_dir, f_name))

        self._barrier()
Example #5
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # every replica draws the same number of samples so all ranks stay in step
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
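
Example #5 mirrors torch.utils.data.distributed.DistributedSampler, but again only the constructor appears. A rough sketch of a matching __iter__, assuming epoch-seeded shuffling and padding so that every replica receives exactly num_samples indices (this body is not part of the original example):

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)  # same permutation on every process
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        # pad so the index list divides evenly across replicas
        indices += indices[:(self.total_size - len(indices))]
        # each rank takes a strided slice of length num_samples
        return iter(indices[self.rank:self.total_size:self.num_replicas])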
Example #6
    def test_send_recv(self):
        rank = dist.get_rank()
        tensor = _build_tensor(rank + 1)
        for dest in range(0, dist.get_world_size()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        for src in range(0, dist.get_world_size()):
            if src == rank:
                continue
            tensor = _build_tensor(src + 1, value=-1)
            expected_tensor = _build_tensor(src + 1)
            dist.recv(tensor, src)
            self.assertEqual(tensor, expected_tensor)

        self._barrier()
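
The point-to-point tests in Examples #6 through #9 rely on a _build_tensor helper that is not reproduced here. Judging from how it is called (a size plus an optional fill value that defaults to the size), a compatible sketch would be:

def _build_tensor(size, value=None):
    # cube of side `size`, filled with `value` (or with `size` itself by default)
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)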
Example #7
    def test_isend(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        if rank == 0:
            requests = [
                dist.isend(_build_tensor(dest, 10), dest)
                for dest in range(1, world_size)
            ]
            for request in requests:
                request.wait()
                self.assertTrue(request.is_completed())
        else:
            tensor = _build_tensor(rank, -1)
            dist.recv(tensor, 0)
            self.assertEqual(tensor, _build_tensor(rank, 10))

        self._barrier()
Example #8
    def test_send_recv_any_source(self):
        rank = dist.get_rank()
        tensor = _build_tensor(10, rank)
        for dest in range(0, dist.get_world_size()):
            if dest == rank:
                continue
            dist.send(tensor, dest)

        recv_ranks = set()
        for src in range(0, dist.get_world_size()):
            if src == rank:
                continue
            tensor = _build_tensor(10, value=-1)
            sender = dist.recv(tensor)
            self.assertTrue(tensor.eq(sender).all())
            recv_ranks.add(sender)

        self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
        self._barrier()
Example #9
    def test_irecv(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        if rank == 0:
            expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
            requests = [
                dist.irecv(expected_tensors[src - 1], src)
                for src in range(1, world_size)
            ]

            for src in range(1, world_size):
                requests[src - 1].wait()
                self.assertTrue(requests[src - 1].is_completed())
                self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
        else:
            tensor = _build_tensor(rank, 10)
            dist.send(tensor, 0)

        self._barrier()
Example #10
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k, v in loss_dict.items():
            loss_names.append(k)
            all_losses.append(v)
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only rank 0 (the dst of the reduce) holds the accumulated sum,
            # so divide by world_size on that rank only
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
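
A typical call site for reduce_loss_dict logs only on the main process, since only rank 0 holds the averaged values. The loss names and values below are placeholders, and get_rank is assumed to be the companion helper from the same utility module as get_world_size:

loss_dict = {
    'loss_cls': torch.tensor(0.5),  # in practice these come from the model
    'loss_reg': torch.tensor(0.1),
}
loss_dict_reduced = reduce_loss_dict(loss_dict)
if get_rank() == 0:
    total_loss = sum(loss for loss in loss_dict_reduced.values())
    print('total loss: {:.4f}'.format(total_loss.item()))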
Example #11
def run():
    modell = model.CNN()
    # modell = model.AlexNet()

    size = dist.get_world_size()
    rank = dist.get_rank()

    # build a group containing every rank
    group = dist.new_group(list(range(size)))

    while True:

        # rank 0 broadcasts its current model parameters to all workers
        for param in modell.parameters():
            # for dst in range(1, size):
            # dist.send(param.data, dst=dst)
            dist.broadcast(param.data, src=0, group=group)

        # sum the workers' updated parameters onto rank 0 and average them;
        # this process contributes zeros, hence the division by size - 1
        for param in modell.parameters():
            tensor_temp = torch.zeros_like(param.data)
            dist.reduce(tensor_temp, dst=0, op=dist.reduce_op.SUM, group=group)
            param.data = tensor_temp / (size - 1)
Example #12
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats
Example #13
def get_rank():
    if _use_c10d[0]:
        return dist_c10d.get_rank()
    else:
        return dist_no_c10d.get_rank()
Example #14
    def _init_global_test(self):
        group = list(range(0, dist.get_world_size()))
        group_id = dist.group.WORLD
        rank = dist.get_rank()
        return (group, group_id, rank)