def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None):
    """
    Group consecutive sample indices into bins of ``batch_size`` and record
    how those bins are split across replicas.

    The data source is assumed to already be ordered by sample size, so that
    each bin holds similarly sized samples.

    Args:
        data_source: sized collection to sample from (only ``len`` is used here).
        batch_size: number of consecutive indices per bin.
        num_replicas: number of participating processes; when ``None`` the
            value of ``get_world_size()`` is used.
        rank: index of this process; when ``None`` the value of
            ``get_rank()`` is used.
    """
    super(DistributedBucketingSampler, self).__init__(data_source)

    # Fall back to the values reported by the distributed helpers.
    if num_replicas is None:
        num_replicas = get_world_size()
    if rank is None:
        rank = get_rank()

    self.data_source = data_source
    self.batch_size = batch_size
    self.num_replicas = num_replicas
    self.rank = rank

    # One id per sample, chunked into consecutive runs of `batch_size`
    # (the last bin may be shorter).
    self.ids = list(range(len(data_source)))
    self.bins = [
        self.ids[start:start + batch_size]
        for start in range(0, len(self.ids), batch_size)
    ]

    # Bins per replica, rounded up so every bin is covered.
    bins_per_replica = float(len(self.bins)) / self.num_replicas
    self.num_samples = int(math.ceil(bins_per_replica))
    self.total_size = self.num_samples * self.num_replicas
def _init_group_test(self):
    """
    Create a process group made of ranks 1 and 2.

    Returns:
        ``(group, group_id, rank)`` when the calling rank belongs to the
        group, otherwise ``([], None, rank)``.
    """
    members = [1, 2]
    # The group is created before the membership check, so every caller
    # executes new_group() regardless of whether it is a member.
    handle = dist.new_group(members)
    my_rank = dist.get_rank()
    if my_rank in members:
        return (members, handle, my_rank)
    return ([], None, my_rank)
def get_dist_info():
    """
    Return ``(rank, world_size)`` for the current process.

    When no distributed process group has been initialised, the
    single-process defaults ``(0, 1)`` are returned.

    Returns:
        tuple: ``(rank, world_size)``.
    """
    # Prefer the public query; the private ``_initialized`` flag only exists
    # in legacy releases and raises AttributeError on current ones.
    is_initialized = getattr(dist, "is_initialized", None)
    if is_initialized is not None:
        initialized = is_initialized()
    else:
        initialized = getattr(dist, "_initialized", False)

    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
def test_get_rank(self):
    """
    Check that the processes report as many distinct ranks as the world size.

    Each process writes its rank into a file named after its pid inside a
    shared directory; after a barrier every process reads all files back and
    compares the number of distinct ranks with the world size.
    """
    scratch_dir = os.path.join(TEMP_DIR, "test_dir")
    pid = str(os.getpid())
    expected_count = dist.get_world_size()

    with open(os.path.join(scratch_dir, pid), "w") as rank_file:
        rank_file.write(str(dist.get_rank()))

    # Wait until every process has written its file.
    self._barrier()

    seen_ranks = set()
    for entry in os.listdir(scratch_dir):
        with open(os.path.join(scratch_dir, entry), "r") as rank_file:
            seen_ranks.add(int(rank_file.read()))
    self.assertEqual(len(seen_ranks), expected_count)

    # Nobody may delete files while another process is still reading.
    self._barrier()

    if dist.get_rank() == 0:
        for entry in os.listdir(scratch_dir):
            os.unlink(os.path.join(scratch_dir, entry))

    self._barrier()
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
    """
    Record the dataset and how its samples are split across replicas.

    Args:
        dataset: sized collection to sample from (only ``len`` is used here).
        num_replicas: number of participating processes; defaults to
            ``dist.get_world_size()`` when ``None``.
        rank: index of this process; defaults to ``dist.get_rank()`` when
            ``None``.
        shuffle: stored as ``self.shuffle`` for use by the rest of the class.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` must be looked up but
            the distributed package is not available.
    """
    # Only touch the distributed package when a value has to be looked up.
    if num_replicas is None or rank is None:
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    # Samples per replica, rounded up so every sample is covered.
    self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
    self.total_size = self.num_samples * self.num_replicas
    # Honour the caller's choice; it was previously hard-coded to True.
    self.shuffle = shuffle
def test_send_recv(self):
    """
    Exchange tensors between every pair of ranks with send/recv.

    Each rank first sends ``_build_tensor(rank + 1)`` to every other rank,
    then receives from each other rank in turn and compares the result with
    ``_build_tensor(src + 1)``.
    """
    my_rank = dist.get_rank()
    outgoing = _build_tensor(my_rank + 1)

    # Send phase: everyone except ourselves.
    for dest in range(dist.get_world_size()):
        if dest != my_rank:
            dist.send(outgoing, dest)

    # Receive phase: the buffer starts at -1 so a missing write is detectable.
    for src in range(dist.get_world_size()):
        if src != my_rank:
            incoming = _build_tensor(src + 1, value=-1)
            reference = _build_tensor(src + 1)
            dist.recv(incoming, src)
            self.assertEqual(incoming, reference)

    self._barrier()
def test_isend(self):
    """
    Check isend from rank 0 to every other rank.

    Rank 0 issues one ``isend`` per destination and waits on each request;
    every other rank receives from rank 0 and compares the result with
    ``_build_tensor(rank, 10)``.
    """
    my_rank = dist.get_rank()
    n_procs = dist.get_world_size()

    if my_rank != 0:
        incoming = _build_tensor(my_rank, -1)
        dist.recv(incoming, 0)
        self.assertEqual(incoming, _build_tensor(my_rank, 10))
    else:
        # Start all sends before waiting on any of them.
        pending = []
        for dest in range(1, n_procs):
            pending.append(dist.isend(_build_tensor(dest, 10), dest))
        for req in pending:
            req.wait()
            self.assertTrue(req.is_completed())

    self._barrier()
def test_send_recv_any_source(self):
    """
    Check recv without an explicit source rank.

    Every rank sends ``_build_tensor(10, rank)`` to all other ranks, then
    performs ``world_size - 1`` source-less receives. Each received tensor
    must equal the sender rank returned by ``recv`` element-wise, and the
    set of senders must have ``world_size - 1`` members.
    """
    my_rank = dist.get_rank()
    outgoing = _build_tensor(10, my_rank)

    for dest in range(dist.get_world_size()):
        if dest != my_rank:
            dist.send(outgoing, dest)

    senders = set()
    # The loop variable only counts receives; recv() itself reports the sender.
    for slot in range(dist.get_world_size()):
        if slot != my_rank:
            incoming = _build_tensor(10, value=-1)
            sender = dist.recv(incoming)
            self.assertTrue(incoming.eq(sender).all())
            senders.add(sender)

    self.assertEqual(len(senders), dist.get_world_size() - 1)
    self._barrier()
def test_irecv(self):
    """
    Check irecv on rank 0 from every other rank.

    Rank 0 posts one ``irecv`` per source, waits on each request and compares
    the buffer with ``_build_tensor(src, 10)``; every other rank sends
    ``_build_tensor(rank, 10)`` to rank 0.
    """
    my_rank = dist.get_rank()
    n_procs = dist.get_world_size()

    if my_rank != 0:
        outgoing = _build_tensor(my_rank, 10)
        dist.send(outgoing, 0)
    else:
        # Buffers are created with -1 and overwritten by the receives.
        buffers = [_build_tensor(src, -1) for src in range(1, n_procs)]
        pending = [
            dist.irecv(buf, src) for src, buf in enumerate(buffers, start=1)
        ]
        for src, (req, buf) in enumerate(zip(pending, buffers), start=1):
            req.wait()
            self.assertTrue(req.is_completed())
            self.assertEqual(buf, _build_tensor(src, 10))

    self._barrier()
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with
    rank 0 has the averaged results.

    Args:
        loss_dict: mapping from loss name to a tensor; all values must be
            stackable with ``torch.stack`` and the keys must be sortable.

    Returns:
        A dict with the same keys as ``loss_dict`` after reduction. With
        fewer than two processes ``loss_dict`` itself is returned unchanged.
        Only rank 0 divides by the world size, so only its values are
        averages.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        # Sort the keys so every process stacks its losses in the same
        # order; dict iteration order is not guaranteed to match across
        # processes, and the reduce below is element-wise.
        loss_names = sorted(loss_dict.keys())
        all_losses = torch.stack([loss_dict[name] for name in loss_names], dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = dict(zip(loss_names, all_losses))
    return reduced_losses
def run():
    """
    Loop forever, synchronising the parameters of a ``model.CNN`` instance
    with the other processes.

    Each iteration broadcasts every parameter from rank 0 to a group that
    contains all ranks, then sum-reduces a zero-initialised tensor per
    parameter onto rank 0 and stores the sum divided by ``size - 1`` as the
    new parameter value. This function never returns.

    NOTE(review): dividing by ``size - 1`` suggests this is meant to run on a
    rank-0 server averaging over the remaining ranks - confirm with callers.
    """
    net = model.CNN()
    # An alternative model.AlexNet() was previously left here commented out.
    size = dist.get_world_size()
    rank = dist.get_rank()
    group = dist.new_group(list(range(size)))

    while True:
        # Push the current parameters out from rank 0 (a broadcast is used
        # rather than individual sends to each destination).
        for param in net.parameters():
            dist.broadcast(param.data, src=0, group=group)

        # This process contributes zeros, so the reduced value is the sum of
        # what the other ranks contribute.
        for param in net.parameters():
            summed = torch.zeros_like(param.data)
            dist.reduce(summed, dst=0, op=dist.reduce_op.SUM, group=group)
            param.data = summed / (size - 1)
def all_gather_stats_list(stat_list, max_size=4096):
    """
    Gather a `Statistics` list across all processes/nodes

    Args:
        stat_list(list([`Statistics`])): list of statistics objects to
            gather across all processes/nodes
        max_size(int): max buffer size to use

    Returns:
        our_stats(list([`Statistics`])): list of updated stats
    """
    # One list of Statistics objects per rank, indexed by rank.
    gathered = all_gather_list(stat_list, max_size=max_size)

    my_rank = get_rank()
    merged = gathered[my_rank]
    for peer_rank, peer_stats in enumerate(gathered):
        # Our own entry is the accumulator, so it is not merged into itself.
        if peer_rank != my_rank:
            for idx, peer_stat in enumerate(peer_stats):
                merged[idx].update(peer_stat, update_n_src_words=True)
    return merged
def get_rank():
    """Return the rank reported by the backend selected through ``_use_c10d[0]``."""
    backend = dist_c10d if _use_c10d[0] else dist_no_c10d
    return backend.get_rank()
def _init_global_test(self):
    """
    Describe the group made of every rank.

    Returns:
        ``(group, group_id, rank)`` where ``group`` lists all ranks from 0 to
        ``world_size - 1``, ``group_id`` is ``dist.group.WORLD`` and ``rank``
        is the calling process's rank.
    """
    all_ranks = list(range(dist.get_world_size()))
    world = dist.group.WORLD
    my_rank = dist.get_rank()
    return (all_ranks, world, my_rank)