def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return
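
# Hedged sketch (not part of the upstream file): a minimal illustration of how a
# DataParallel-style wrapper might call _check_balance before scattering work.
# The device ids and the torch.cuda guard below are assumptions for the example.
def _example_check_balance():
    import torch
    if torch.cuda.device_count() >= 2:
        # Emits a UserWarning if one GPU has less than 75% of the memory or
        # multiprocessor count of the strongest GPU in the list.
        _check_balance([0, 1])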
def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [module_copies[j][0] for j in range(num_replicas)]
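
# Hedged sketch (not part of the upstream file): how a data-parallel forward pass
# might use replicate() together with the scatter/parallel_apply/gather helpers
# exported by torch.nn.parallel. The module, batch shape, and device list are
# assumptions for the example.
def _example_replicate_forward():
    import torch
    from torch.nn.parallel import scatter, parallel_apply, gather

    if torch.cuda.device_count() < 2:
        return None

    devices = [0, 1]
    net = torch.nn.Linear(8, 4).cuda(devices[0])

    # Split the batch along dim 0, one shard per device.
    inputs = scatter(torch.randn(6, 8, device="cuda:0"), devices)
    # One copy of the module per device; parameters stay tied to the original graph.
    replicas = replicate(net, devices)
    # Run each replica on its shard, then concatenate on the output device.
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, devices[0])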
def __init__(self, module, device_ids=None,
             output_device=None, dim=0, broadcast_buffers=True,
             process_group=None,
             bucket_cap_mb=25,
             find_unused_parameters=False,
             check_reduction=False):

    super(DistributedDataParallel, self).__init__()

    assert any((p.requires_grad for p in module.parameters())), (
        "DistributedDataParallel is not needed when a module "
        "doesn't have any parameter that requires a gradient.")

    self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
    distinct_device_types = {p.device.type for p in module.parameters()}
    assert len(distinct_device_types) == 1, (
        "DistributedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {}."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    if self.device_type == "cpu" or self.is_multi_device_module:
        assert not device_ids and not output_device, (
            "DistributedDataParallel device_ids and output_device arguments "
            "only work with single-device GPU modules, but got "
            "device_ids {}, output_device {}, and module parameters {}."
        ).format(device_ids, output_device, {p.device for p in module.parameters()})

        self.device_ids = None
        self.output_device = None
    else:
        # Use all devices by default for single-device GPU modules
        if device_ids is None:
            device_ids = _get_all_device_indices()

        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))

        if output_device is None:
            output_device = device_ids[0]

        self.output_device = _get_device_index(output_device, True)

    if process_group is None:
        self.process_group = _get_default_group()
    else:
        self.process_group = process_group

    self.dim = dim
    self.module = module
    self.broadcast_buffers = broadcast_buffers
    self.find_unused_parameters = find_unused_parameters
    self.require_backward_grad_sync = True
    self.require_forward_param_sync = True

    if check_reduction:
        # This argument is no longer used since the reducer
        # will ensure reduction completes even if some parameters
        # do not receive gradients.
        pass

    # used for intra-node param sync and inter-node sync as well
    self.broadcast_bucket_size = int(250 * 1024 * 1024)

    # reduction bucket size
    self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    # Sync params and buffers
    module_states = list(self.module.state_dict().values())
    if len(module_states) > 0:
        self._distributed_broadcast_coalesced(module_states,
                                              self.broadcast_bucket_size)

    self._ddp_init_helper()
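
# Hedged sketch (not part of the upstream file): a minimal single-node setup that
# exercises this constructor through the public DistributedDataParallel class.
# The rendezvous environment variables, backend choice, and model are assumptions
# for the example; in practice they come from the launcher.
def _example_ddp_init(rank, world_size):
    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Single-device GPU module: device_ids must name exactly one device, and the
    # constructor broadcasts rank 0's state_dict to every other rank.
    model = torch.nn.Linear(8, 4).cuda(rank)
    return DistributedDataParallel(model, device_ids=[rank])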