def _check_default_group(self):
    """Verify this instance uses the default process group.

    DDP instances can only be pickled/unpickled when they were built on
    the default process group, because ``__setstate__`` unconditionally
    rebinds to it.  Raises ``RuntimeError`` otherwise.
    """
    unsupported = False
    try:
        # get_default_group() itself raises RuntimeError when no default
        # group has been initialized; treat that as unsupported too.
        unsupported = self.process_group != dist.get_default_group()
    except RuntimeError:
        unsupported = True
    if unsupported:
        raise RuntimeError("DDP Pickling/Unpickling are only supported "
                           "when using DDP with the default process "
                           "group. That is, when you have called "
                           "init_process_group and have not passed "
                           "process_group argument to DDP constructor")
def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, check_reduction=False): super(DistributedDataParallel, self).__init__() # Use all devices by default if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] if process_group is None: self.process_group = dist.get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers self.check_reduction = check_reduction MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = 250 * MB # reduction bucket size self.bucket_bytes_cap = bucket_cap_mb * MB # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()
def __setstate__(self, state):
    """Restore a pickled DDP instance.

    Pickling is only supported with the default process group (see
    ``_check_default_group``), so the group is rebound to the default one
    before the parent state is restored, and gradient hooks — which are
    not picklable — are re-registered from scratch.
    """
    # If serializable, then the process group should be the default one
    self.process_group = dist.get_default_group()
    super(DistributedDataParallel, self).__setstate__(state)
    self._register_grad_hooks()
def __init__(self, module, device_ids=None, output_device=None, dim=0,
             broadcast_buffers=True, process_group=None, bucket_cap_mb=25):
    """Wrap ``module`` for distributed data-parallel training.

    Replicates the module across ``device_ids``, broadcasts the initial
    state from rank 0, partitions parameters into fixed-size gradient
    reduction buckets, and registers the backward hooks that drive the
    asynchronous all-reduce.

    Args:
        module: the model to replicate and synchronize.
        device_ids: CUDA device indices/devices to use; defaults to all
            visible CUDA devices.
        output_device: device for the output; defaults to the first entry
            of ``device_ids``.
        dim: dimension along which inputs are scattered.
        broadcast_buffers: whether buffers are broadcast at forward time.
        process_group: process group to use; defaults to the default group
            created by ``init_process_group``.
        bucket_cap_mb: gradient-reduction bucket size cap, in MiB.
    """
    super(DistributedDataParallel, self).__init__()
    # Use all devices by default
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        # NOTE(review): raises IndexError if no CUDA devices are visible —
        # presumably CUDA availability is validated upstream; confirm.
        output_device = device_ids[0]
    if process_group is None:
        self.process_group = dist.get_default_group()
    else:
        self.process_group = process_group
    self.dim = dim
    self.module = module
    # Normalize devices to plain integer indices.
    self.device_ids = list(
        map(lambda x: _get_device_index(x, True), device_ids))
    self.output_device = _get_device_index(output_device, True)
    self.broadcast_buffers = broadcast_buffers
    MB = 1024 * 1024
    # used for intra-node param sync and inter-node sync as well
    self.broadcast_bucket_size = 250 * MB
    # Sync params and buffers: broadcast rank 0's full state so every
    # process starts from identical weights.
    module_states = list(self.module.state_dict().values())
    if len(module_states) > 0:
        self._dist_broadcast_coalesced(module_states,
                                       self.broadcast_bucket_size)
    if len(device_ids) > 1:
        # TODO: we don't need to replicate params in here. they're always going to
        # be broadcasted using larger blocks in broadcast_coalesced, so it might be
        # better to not pollute the caches with these small blocks
        self._module_copies = replicate(self.module, self.device_ids,
                                        detach=True)
        self._module_copies[0] = self.module
        # Replicas are detached; restore requires_grad so hooks fire on
        # the same parameter set as the source module.
        for module_copy in self._module_copies[1:]:
            for param, copy_param in zip(self.module.parameters(),
                                         module_copy.parameters()):
                copy_param.requires_grad = param.requires_grad
    else:
        self._module_copies = [self.module]

    # Per-device flat views of parameter and buffer tensors, used for
    # intra-node synchronization.
    self.modules_params_data = [[] for _ in range(len(self.device_ids))]
    self.modules_buffers_data = [[] for _ in range(len(self.device_ids))]
    for dev_idx, module in enumerate(self._module_copies):
        self.modules_params_data[dev_idx] = [
            p.data for p in module.parameters()
        ]
        self.modules_buffers_data[dev_idx] = [
            b.data for b in module.buffers()
        ]

    bucket_bytes_cap = bucket_cap_mb * MB
    # Split the parameters into buckets and by types as well.  This is a
    # triply-nested list where the "dimensions" are: devices, buckets,
    # bucket_elems.  (A redundant `param_buckets = []` that was
    # immediately overwritten has been removed.)
    param_buckets = [
        list(_take_tensors(m.parameters(), bucket_bytes_cap))
        for m in self._module_copies
    ]

    self.bucket_sizes = []
    self.bucket_map = {}

    # We transpose param_buckets, so the loop is over buckets.
    # param_buckets_tuple is a doubly-nested list with "dims": devices,
    # bucket_elems.
    for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
        self.bucket_sizes.append(0)
        # Now, we transpose again, so we iterate over bucket_elems, but
        # getting tuples of params from each device.
        for param_tuple in zip(*param_buckets_tuple):
            if not param_tuple[0].requires_grad:
                continue
            for p in param_tuple:
                # Each param maps to (its bucket, its slot in the bucket).
                self.bucket_map[p] = (bucket_idx,
                                      self.bucket_sizes[bucket_idx])
            self.bucket_sizes[bucket_idx] += 1

    # Gradient slots, indexed [bucket][device][elem]; filled in by the
    # grad hooks as gradients become ready.
    self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                     for _ in range(len(self.device_ids))]
                    for i in range(len(self.bucket_sizes))]
    # The number of params ready in each bucket
    self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))]
                               for i in range(len(self.bucket_sizes))]
    # coalesced bucket for only device 0
    self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
    # We will always reduce the bucket following the reverse order
    # that is, alway reduces following the order of: n - 1, n - 2, ..., 0
    self.next_bucket = len(self.bucket_sizes) - 1
    self.ready_buckets_not_reduced = set()
    self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
    self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
    self._register_grad_hooks()
def test_get_default_group(self):
    """After init_process_group, a default group must be retrievable."""
    group = dist.get_default_group()
    self.assertNotEqual(group, None)