def test_allreduce(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=511, size=1024)
    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    with xm_cc_op_intercepted('all_reduce'):
        pg_xla.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # Purge all computations attached to the device.
    xm.mark_step()
def _test_allreduce_basics(self, fn):
    store = c10d.FileStore(self.file_name, self.world_size)
    pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size)

    # Single input tests
    tests = simple_reduce_tests(self.rank, self.world_size)
    for (op, input, output) in tests:
        opts = c10d.AllreduceOptions()
        opts.reduceOp = op
        tensor = fn(input)
        work = pg.allreduce([tensor], opts)
        work.wait()
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(output, tensor)
def _register_grad_hooks(self):
    self._grad_accs = []  # need to keep them in scope

    # default stream tracking to launch nccl reduce kernels
    self.default_streams = []
    for dev_id in self.device_ids:
        with torch.cuda.device(dev_id):
            self.default_streams.append(torch.cuda.current_stream())

    self.allreduce_opts = dist.AllreduceOptions()

    for device_idx, module in enumerate(self._module_copies):
        for p in module.parameters():
            if p.requires_grad:
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(p, device_idx))
                self._grad_accs.append(grad_acc)
def test_allreduce_with_mesh(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    set_world_size(6)
    ranks = [2, 3]
    world_rank = 3
    set_world_rank(world_rank)
    with new_group_barrier_disabled():
        new_pg = dist.new_group(ranks=ranks)
    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    all_reduce_pattern = (r'%all\-reduce\.\d+ = .+ all\-reduce\(.+\), .*'
                          r'replica_groups=\{\{0,1\},\{2,3\},\{4,5\}\}')
    with xm_cc_op_intercepted('all_reduce'):
        new_pg.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # Purge all computations attached to the device.
    xm.mark_step()
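The three replica groups expected by the regex follow from tiling the 2-rank group across the 6-device world; a quick arithmetic sketch of that expansion (illustrative only, not the torch_xla implementation):

world_size = 6   # set_world_size(6) above
group_size = 2   # len(ranks) for the new group
replica_groups = [list(range(start, start + group_size))
                  for start in range(0, world_size, group_size)]
print(replica_groups)  # [[0, 1], [2, 3], [4, 5]] -> {{0,1},{2,3},{4,5}} in HLO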
def allreduce(tensors, op):
    # Assumes a ProcessGroup `pg` in the enclosing scope.
    opts = c10d.AllreduceOptions()
    opts.reduceOp = op
    work = pg.allreduce(tensors, opts)
    work.wait()
def allreduce(x, op):
    # Single-tensor variant of the helper above.
    opts = c10d.AllreduceOptions()
    opts.reduceOp = op
    work = pg.allreduce([x], opts)
    work.wait()
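For context, a minimal sketch of how helpers like these are typically driven, assuming a gloo build of torch.distributed where ProcessGroupGloo can be constructed directly from a store; the store path and single-rank world below are illustrative, not from the original:

import tempfile

import torch
import torch.distributed as c10d

# Single process standing in for a world of size 1 (illustrative only).
store = c10d.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
pg = c10d.ProcessGroupGloo(store, 0, 1)

def allreduce(tensors, op):
    opts = c10d.AllreduceOptions()
    opts.reduceOp = op
    work = pg.allreduce(tensors, opts)
    work.wait()

x = torch.ones(4)
allreduce([x], c10d.ReduceOp.SUM)  # no-op at world size 1; x stays all ones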
def __init__(self, module, device_ids=None,
             output_device=None, dim=0, broadcast_buffers=True,
             process_group=None, bucket_cap_mb=25):
    super(DistributedDataParallel, self).__init__()

    # Use all devices by default
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    if process_group is None:
        self.process_group = dist.get_default_group()
    else:
        self.process_group = process_group

    self.dim = dim
    self.module = module
    self.device_ids = device_ids
    self.output_device = output_device
    self.broadcast_buffers = broadcast_buffers
    self.allreduce_opts = dist.AllreduceOptions()

    MB = 1024 * 1024

    # used for intra-node param sync and inter-node sync as well
    self.broadcast_bucket_size = 25 * MB

    # Sync params and buffers
    module_states = list(self.module.state_dict().values())
    if len(module_states) > 0:
        self._dist_broadcast_coalesced(module_states,
                                       self.broadcast_bucket_size)

    if len(device_ids) > 1:
        # TODO: we don't need to replicate params in here. they're always going to
        # be broadcasted using larger blocks in broadcast_coalesced, so it might be
        # better to not pollute the caches with these small blocks
        self._module_copies = replicate(self.module, self.device_ids, detach=True)
        self._module_copies[0] = self.module

        for module_copy in self._module_copies[1:]:
            for param, copy_param in zip(self.module.parameters(),
                                         module_copy.parameters()):
                copy_param.requires_grad = param.requires_grad
    else:
        self._module_copies = [self.module]

    self.modules_params_data = [[] for _ in range(len(self.device_ids))]
    self.modules_buffers_data = [[] for _ in range(len(self.device_ids))]

    for dev_idx, module in enumerate(self._module_copies):
        self.modules_params_data[dev_idx] = [p.data for p in module.parameters()]
        self.modules_buffers_data[dev_idx] = [b.data for b in module.buffers()]

    bucket_bytes_cap = bucket_cap_mb * MB

    # This is a triply-nested list where the "dimensions" are:
    # devices, buckets, bucket_elems
    # Split the parameters into buckets and by types as well
    param_buckets = [
        list(_take_tensors(m.parameters(), bucket_bytes_cap))
        for m in self._module_copies
    ]

    self.bucket_sizes = []
    self.bucket_map = {}

    # We transpose param_buckets, so the loop is over buckets.
    # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
    for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
        self.bucket_sizes.append(0)
        # Now, we transpose again, so we iterate over bucket_elems, but getting
        # tuples of params from each device.
        for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):
            if not param_tuple[0].requires_grad:
                continue
            for p in param_tuple:
                self.bucket_map[p] = (bucket_idx, idx)
            self.bucket_sizes[bucket_idx] += 1

    self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                     for _ in range(len(self.device_ids))]
                    for i in range(len(self.bucket_sizes))]
    # The number of params ready in each bucket
    self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))]
                               for i in range(len(self.bucket_sizes))]

    # coalesced bucket for only device 0
    self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
    # We will always reduce the buckets in reverse order,
    # that is, following the order: n - 1, n - 2, ..., 0
    self.next_bucket = len(self.bucket_sizes) - 1
    self.ready_buckets_not_reduced = set()
    self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
    self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]

    # default stream tracking to launch nccl reduce kernels
    self.default_streams = []
    for dev_id in self.device_ids:
        with torch.cuda.device(dev_id):
            self.default_streams.append(torch.cuda.current_stream())

    self._register_grad_hooks()
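The double transpose in the bucket-building loop is the densest part of this constructor; the following standalone sketch replays the same indexing on toy data, with strings standing in for parameter tensors (all names illustrative):

# Toy illustration of the double transpose used to build bucket_map above.
param_buckets = [
    [['d0.w0', 'd0.b0'], ['d0.w1']],  # device 0: two buckets
    [['d1.w0', 'd1.b0'], ['d1.w1']],  # device 1: same bucket layout
]

bucket_map = {}
bucket_sizes = []
# zip(*param_buckets): outer loop runs per bucket, yielding that bucket's
# contents on every device.
for bucket_idx, per_device_bucket in enumerate(zip(*param_buckets)):
    bucket_sizes.append(0)
    # zip(*per_device_bucket): inner loop runs per parameter, yielding the
    # tuple of that parameter's replicas across devices.
    for idx, param_tuple in enumerate(zip(*per_device_bucket)):
        for p in param_tuple:
            bucket_map[p] = (bucket_idx, idx)
        bucket_sizes[bucket_idx] += 1

print(bucket_sizes)         # [2, 1]
print(bucket_map['d1.b0'])  # (0, 1): bucket 0, element 1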