def __init__(self, module, device_ids=None, output_device=None, dim=0): super(torch.nn.DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = [] return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.src_device_obj = torch.device("cuda:{}".format( self.device_ids[0])) _check_balance(self.device_ids) if len(self.device_ids) == 1: self.module.cuda(device_ids[0])
def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = [] return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] self.dim = dim self.module = module self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0])) self.n_device = len(device_ids) self.seen = 0 _check_balance(self.device_ids) if len(self.device_ids) == 1: self.module.cuda(device_ids[0])
def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallelImbalance, self).__init__(module, device_ids, output_device, dim) if not torch.cuda.is_available(): self.module = module self.device_ids = [] return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] if not all(t.is_cuda and t.device.index == device_ids[0] for t in chain(module.parameters(), module.buffers())): raise RuntimeError("module must have its parameters and buffers " "on device %d (device_ids[0])" % device_ids[0]) self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) if len(self.device_ids) == 1: self.module.cuda(device_ids[0])
def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, check_reduction=False, randk=1, seed=2147483647): super(RandomKSparsifiedDDP, self).__init__() torch.manual_seed(seed) # Use all devices by default if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if len(device_ids) > 1: raise RuntimeError( "This module only supports Multi-Process Single-GPU mode.") if output_device is None: output_device = device_ids[0] if process_group is None: self.process_group = _get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers self.check_reduction = check_reduction self.randk = randk self.masks = {} self.global_step = 0 MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = 250 * MB # reduction bucket size self.bucket_bytes_cap = bucket_cap_mb * MB # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()
def data_parallel( module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, non_scatter_kwargs=None, ): r"""Evaluates module(input) in parallel across the GPUs given in device_ids. This is the functional version of the DataParallel module. Args: module (Module): the module to evaluate in parallel inputs (Tensor): inputs to the module device_ids (list of int or torch.device): GPU ids on which to replicate module output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU. (default: device_ids[0]) Returns: a Tensor containing the result of module(input) located on output_device """ if not isinstance(inputs, tuple): inputs = (inputs,) if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) output_device = _get_device_index(output_device, True) src_device_obj = torch.device("cuda:{}".format(device_ids[0])) for tensor in chain(module.parameters(), module.buffers()): if tensor.device != src_device_obj: raise RuntimeError( "module must have its parameters and buffers " "on device {} (device_ids[0]) but found one of " "them on device: {}".format(src_device_obj, tensor.device) ) inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) if len(device_ids) == 1: return module(*inputs[0], **module_kwargs[0]) used_device_ids = device_ids[: len(inputs)] replicas = replicate(module, used_device_ids) outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) return gather(outputs, output_device, dim)
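A minimal usage sketch for the functional form above. It assumes at least one visible CUDA device and calls the stock torch.nn.parallel.data_parallel entry point, which shares the core signature of the snippet; the SmallNet module and tensor shapes are illustrative.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

if torch.cuda.is_available():
    device_ids = list(range(torch.cuda.device_count()))
    # Parameters and buffers must already live on device_ids[0].
    model = SmallNet().cuda(device_ids[0])
    x = torch.randn(8, 16, device="cuda:{}".format(device_ids[0]))
    # The batch is scattered across device_ids, replicas run in parallel,
    # and the per-device outputs are gathered back on output_device.
    out = data_parallel(model, x, device_ids=device_ids,
                        output_device=device_ids[0])
    print(out.shape)  # torch.Size([8, 4])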
def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, check_reduction=False): super(DistributedDataParallel, self).__init__() # Use all devices by default if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] if process_group is None: self.process_group = _get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers if check_reduction: # This argument is no longer used since the reducer # will ensure reduction completes even if some parameters # do not receive gradients. pass MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = int(250 * MB) # reduction bucket size self.bucket_bytes_cap = int(bucket_cap_mb * MB) # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()
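A hedged, single-process sketch of constructing the DistributedDataParallel variant above. The gloo backend, TCP address/port, and world size of 1 are illustrative assumptions; a default process group must exist before the wrapper is built because the constructor falls back to _get_default_group() when process_group is None.

import torch
import torch.distributed as dist
import torch.nn as nn

# Create the default process group first (illustrative rendezvous settings).
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

model = nn.Linear(16, 4).cuda(0)
ddp_model = DistributedDataParallel(model, device_ids=[0], output_device=0)
out = ddp_model(torch.randn(8, 16, device="cuda:0"))
out.sum().backward()  # gradients are allreduced across the (single) process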
def __init__( self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, check_reduction=False, find_unused_parameters=False, ): super(DistributedDataParallelPythonBuckets, self).__init__() # Use all devices by default if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] if process_group is None: self.process_group = _get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers self.check_reduction = check_reduction MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = 250 * MB # reduction bucket size self.bucket_bytes_cap = bucket_cap_mb * MB # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()
def __init__(self, module, device_id=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, check_reduction=False, sparse_ratio=0.1, sparse_threshold=1024, mem_decay=1.0): super(DistributedDataParallel, self).__init__() self.module = module if device_id is None: raise RuntimeError("device_id cannot be None") if output_device is None: output_device = device_id if process_group is None: self.process_group = _get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_id = _get_device_index(device_id, True) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers self.check_reduction = check_reduction self.sparse_ratio = sparse_ratio self.sparse_threshold = sparse_threshold self.mem_decay = mem_decay MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = 250 * MB module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()
def _to_device_index(devices): if not devices: raise RuntimeError("Cannot replicate using an empty device list.") if isinstance(devices, list) and isinstance(devices[0], list): device_ids = [] seen = set() for i, replica_devs in enumerate(devices): assert len(replica_devs) == len(devices[0]), ( "Cannot replicate to an unequal number of devices, but got " "device list {} and {} for replicas {} and {}.").format( devices[0], devices[i], 0, i) assert len(seen.intersection(replica_devs)) == 0, ( "Devices {} are shared by multiple replicas.").format( seen.intersection(replica_devs)) seen.update(replica_devs) device_ids.append(_to_device_index(replica_devs)) return device_ids else: assert len(devices) == len( set(devices)), ("Duplicated device ids {}.").format(devices) return list(map(lambda x: _get_device_index(x, True), devices))
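An illustration of the flat versus nested handling in _to_device_index above; the expected return values assume _get_device_index resolves ints and torch.device objects to plain integer indices.

import torch

# Flat list: every entry is normalized to an integer index.
_to_device_index([torch.device("cuda:0"), 1])     # -> [0, 1]

# Nested list: one inner list per model replica, normalized recursively.
_to_device_index([[0, 1], [2, 3]])                # -> [[0, 1], [2, 3]]

# Rejected inputs: replicas with different device counts, devices shared
# between replicas, or duplicated ids in a flat list all trip the asserts.
# _to_device_index([[0, 1], [2]])      # unequal number of devices
# _to_device_index([[0, 1], [1, 2]])   # device 1 shared by two replicas
# _to_device_index([0, 0])             # duplicated device ids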
def _check_balance(device_ids): imbalance_warn = """ There is an imbalance between your GPUs. You may want to exclude GPU {} which has less than 75% of the memory or cores of GPU {}. You can do so by setting the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES environment variable.""" device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) dev_props = [torch.cuda.get_device_properties(i) for i in device_ids] def warn_imbalance(get_prop): values = [get_prop(props) for props in dev_props] min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) if min_val / max_val < 0.75: warnings.warn( imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) return True return False if warn_imbalance(lambda props: props.total_memory): return if warn_imbalance(lambda props: props.multi_processor_count): return
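A sketch of acting on the warning emitted above: keep only devices whose total memory and SM count are within 75% of the strongest device, mirroring the _check_balance heuristic. The selection logic and the DataParallel usage are illustrative and assume at least one CUDA device is visible.

import torch

props = [torch.cuda.get_device_properties(i)
         for i in range(torch.cuda.device_count())]
max_mem = max(p.total_memory for p in props)
max_sms = max(p.multi_processor_count for p in props)

# Same 0.75 threshold used by warn_imbalance above.
balanced_ids = [i for i, p in enumerate(props)
                if p.total_memory / max_mem >= 0.75
                and p.multi_processor_count / max_sms >= 0.75]

model = torch.nn.DataParallel(torch.nn.Linear(16, 4).cuda(balanced_ids[0]),
                              device_ids=balanced_ids)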
def replicate_module(self, module: torch.nn.Module, devices: List[int]) -> List[torch.nn.Module]: assert self.n_mask_samples % len(devices) == 0 copies = replicate(module, devices) def walk(module: torch.nn.Module, copy: torch.nn.Module): module_map = {id(module): copy} for name, ref in module._modules.items(): module_map.update(walk(ref, getattr(copy, name))) return module_map devices = [_get_device_index(d) for d in devices] # Copy the custom parameters all_params = [p.get() for p in self.pointer_values] if (not self.masking_enabled) or (not self.training): scattered = _broadcast_coalesced_reshape(all_params, devices) else: # Here is more complicated, because there might be non-masked parameters which has to be handled in the # usual way masked_indices = [ i for i, n in enumerate(self.param_names) if self.is_masked(n) ] simple_indices = [ i for i, n in enumerate(self.param_names) if not self.is_masked(n) ] masked_params = scatter([all_params[i] for i in masked_indices], devices) simple_params = _broadcast_coalesced_reshape( [all_params[i] for i in simple_indices], devices) scattered = [[None] * len(all_params) for _ in devices] for d in range(len(devices)): for mi, mp in zip(masked_indices, masked_params[d]): scattered[d][mi] = mp for si, sp in zip(simple_indices, simple_params[d]): scattered[d][si] = sp for i, c in enumerate(copies): device_map = walk(module, c) for j, p in enumerate(self.pointer_values): setattr(device_map[id(p.parent)], p.name, scattered[i][j]) self.update_rnn_params(c) return copies
def forward(ctx, target_device, dim, *inputs): assert all(map(lambda i: i.is_cuda, inputs)) target_device = _get_device_index(target_device, True) ctx.target_device = target_device ctx.dim = dim ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs)) if all(t.dim() == 0 for t in inputs) and dim == 0: inputs = tuple(t.view(1) for t in inputs) warnings.warn('Was asked to gather along dimension 0, but all ' 'input tensors were scalars; will instead unsqueeze ' 'and return a vector.') ctx.unsqueezed_scalar = True else: ctx.unsqueezed_scalar = False ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs)) return comm.gather(inputs, ctx.dim, ctx.target_device)
def forward(ctx, target_gpus, chunk_sizes, dim, input): target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus)) ctx.dim = dim ctx.input_device = input.get_device() if input.is_cuda else -1 streams = None if ctx.input_device == -1: # Perform CPU to GPU copies in a background stream streams = [_get_stream(device) for device in target_gpus] outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) # Synchronize with the copy stream if streams is not None: for i, output in enumerate(outputs): with torch.cuda.device(target_gpus[i]): main_stream = torch.cuda.current_stream() main_stream.wait_stream(streams[i]) output.record_stream(main_stream) return outputs
def forward(ctx, target_gpus, *inputs): if not all(input.is_cuda for input in inputs): raise TypeError('Broadcast function not implemented for CPU tensors') target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus)) ctx.target_gpus = target_gpus if len(inputs) == 0: return tuple() ctx.num_inputs = len(inputs) ctx.input_device = inputs[0].get_device() outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) non_differentiables = [] for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]): if not input_requires_grad: for output in outputs: non_differentiables.append(output[idx]) ctx.mark_non_differentiable(*non_differentiables) return tuple([t for tensors in outputs for t in tensors])
def __init__(self, module, device_ids, dim=0, broadcast_buffers=False, find_unused_parameters=False, **kwargs): super().__init__() assert len(device_ids) == 1, ( 'Currently, DistributedDataParallelWrapper only supports a ' 'single CUDA device for each process. ' f'The length of device_ids must be 1, but got {len(device_ids)}.') self.module = module self.dim = dim self.to_ddp(device_ids=device_ids, dim=dim, broadcast_buffers=broadcast_buffers, find_unused_parameters=find_unused_parameters, **kwargs) self.output_device = _get_device_index(device_ids[0], True)
def __init__(self, module, split_size, device_ids=None, output_device=None): """ This method instantiates a ModelParallel Module from a module instance passed by the user. The model must have a single input (forward(x) type of signature for the forward method) otherwise an error is returned. An example is here: .. code-block:: python from eisen.utils import ModelParallel from eisen.models.segmentation import UNet model = ModelParallel( module=UNet(input_channels=1, output_channels=1), split_size=2, device_ids=[0, 1, 2, 3], output_device=0 ) :param module: an instance of the model that should be parallelized :type module: torch.nn.Module :param split_size: split size for pipelined execution :type split_size: int :param device_ids: list of int or torch devices indicating GPUs to use :type device_ids: list :param output_device: int or torch device indicating output devices :type output_device: int or torch device """ super(ModelParallel, self).__init__() module_argument_list = inspect.getfullargspec(module.forward)[0] if len(module_argument_list) > 2: raise NotImplementedError( 'Support for modules with more than one input is not yet implemented.' ) self.first_run = True self.split_size = split_size if not torch.cuda.is_available(): self.module = module self.device_ids = [] self.first_run = False return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) if len(self.device_ids) == 1: self.first_run = False self.module.cuda(device_ids[0])
def replicate(network, devices, detach=False): if not _replicatable_module(network): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = _broadcast_coalesced_reshape(params, devices, detach) buffers = list(network.buffers()) buffers_rg = [] buffers_not_rg = [] for buf in buffers: if buf.requires_grad and not detach: buffers_rg.append(buf) else: buffers_not_rg.append(buf) buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} buffer_indices_not_rg = { buf: idx for idx, buf in enumerate(buffers_not_rg) } buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} scriptmodule_skip_attr = { "_parameters", "_buffers", "_modules", "forward", "_c" } for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): if _is_script_module(module): # we have to initialize ScriptModule properly so that # it works with pybind11 def init_fn(script_module): # Don't do anything here, we'll initialize the ScriptModule below return replica = torch.jit.RecursiveScriptModule._construct( module._c._replicate_for_data_parallel(), init_fn) else: replica = module._replicate_for_data_parallel() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, module_copies[j][module_idx]) for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] param = param_copies[j][param_idx] setattr( replica, key, Parameter(param, requires_grad=param.requires_grad)) # TODO: We need to manually set _parameters with a bare # non-parameter Tensor, otherwise gradients don't # accumulate in the original parameters when you call # backwards() on the DataParallel module. replica._parameters[key] = param for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: if buf.requires_grad and not detach: buffer_copies = buffer_copies_rg buffer_idx = buffer_indices_rg[buf] else: buffer_copies = buffer_copies_not_rg buffer_idx = buffer_indices_not_rg[buf] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, buffer_copies[j][buffer_idx]) return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices, detach=False): if not _replicatable_module(network): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = _broadcast_coalesced_reshape(params, devices, detach) buffers = list(network.buffers()) buffers_rg = [] buffers_not_rg = [] for buf in buffers: if buf.requires_grad and not detach: buffers_rg.append(buf) else: buffers_not_rg.append(buf) buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"} for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): replica = module._replicate_for_data_parallel() # This is a temporary fix for DDP. DDP needs to access the # replicated model parameters. It used to do so through # `mode.parameters()`. The fix added in #33907 for DP stops the # `parameters()` API from exposing the replicated parameters. # Hence, we add a `_former_parameters` dict here to support DDP. replica._former_parameters = OrderedDict() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, module_copies[j][module_idx]) for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] param = param_copies[j][param_idx] # parameters in replicas are no longer leaves, # so setattr them as non-parameter attributes setattr(replica, key, param) # expose the parameter for DDP replica._former_parameters[key] = param for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: if buf.requires_grad and not detach: buffer_copies = buffer_copies_rg buffer_idx = buffer_indices_rg[buf] else: buffer_copies = buffer_copies_not_rg buffer_idx = buffer_indices_not_rg[buf] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, buffer_copies[j][buffer_idx]) return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices, detach=False): if not _replicatable_module(network): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = _broadcast_coalesced_reshape(params, devices, detach) buffers = list(network.buffers()) buffers_rg = [] buffers_not_rg = [] for buf in buffers: if buf.requires_grad and not detach: buffers_rg.append(buf) else: buffers_not_rg.append(buf) buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} buffer_indices_not_rg = { buf: idx for idx, buf in enumerate(buffers_not_rg) } buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} scriptmodule_skip_attr = { "_parameters", "_buffers", "_modules", "forward", "_c" } for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): if _is_script_module(module): # we have to initialize ScriptModule properly so that # it works with pybind11 def init_fn(script_module): # Don't do anything here, we'll initialize the ScriptModule below return replica = torch.jit.RecursiveScriptModule._construct( module._c._replicate_for_data_parallel(), init_fn) else: replica = module._replicate_for_data_parallel() # This is a temporary fix for DDP. DDP needs to access the # replicated model parameters. It used to do so through # `mode.parameters()`. The fix added in #33907 for DP stops the # `parameters()` API from exposing the replicated parameters. # Hence, we add a `_former_parameters` dict here to support DDP. replica._former_parameters = OrderedDict() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, module_copies[j][module_idx]) for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] param = param_copies[j][param_idx] # parameters in replicas are no longer leaves, so remove them from _parameters # and setattr them as non-parameter attributes # scripted modules don't allow deleting parameters, but also don't complain # on assigning non-Parameter type if (not _is_script_module(replica)): del replica._parameters[key] setattr(replica, key, param) # expose the parameter for DDP replica._former_parameters[key] = param for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: if buf.requires_grad and not detach: buffer_copies = buffer_copies_rg buffer_idx = buffer_indices_rg[buf] else: buffer_copies = buffer_copies_not_rg buffer_idx = buffer_indices_not_rg[buf] for j in range(num_replicas): replica = module_copies[j][i] setattr(replica, key, buffer_copies[j][buffer_idx]) return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices, detach=False): if not _replicatable_module(network): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = _broadcast_coalesced_reshape(params, devices, detach) buffers = list(network.buffers()) buffers_rg = [] buffers_not_rg = [] for buf in buffers: if buf.requires_grad and not detach: buffers_rg.append(buf) else: buffers_not_rg.append(buf) buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"} for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): if _is_script_module(module): # we have to initialize ScriptModule properly so that # it works with pybind11 replica = _init_script_module() keys = set(module.__dict__.keys()) - scriptmodule_skip_attr for key in keys: replica.__dict__[key] = module.__dict__[key] else: replica = module.__new__(type(module)) replica.__dict__ = module.__dict__.copy() replica._parameters = replica._parameters.copy() replica._buffers = replica._buffers.copy() replica._modules = replica._modules.copy() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = module_copies[j][module_idx] for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = param_copies[j][param_idx] for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: if buf.requires_grad and not detach: buffer_copies = buffer_copies_rg buffer_idx = buffer_indices_rg[buf] else: buffer_copies = buffer_copies_not_rg buffer_idx = buffer_indices_not_rg[buf] for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = buffer_copies[j][buffer_idx] for j in range(num_replicas): _copy_scriptmodule_methods(modules, module_copies[j], module_indices) return [module_copies[j][0] for j in range(num_replicas)]
def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): assert len(modules) == len(inputs) assert len(targets) == len(inputs) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({}, ) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) devices = list(map(lambda x: _get_device_index(x, True), devices)) lock = threading.Lock() results = {} grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, target, kwargs, device=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): if not isinstance(input, (list, tuple)): input = (input, ) if not isinstance(target, (list, tuple)): target = (target, ) output = module(*input, *target, **kwargs) with lock: results[i] = output except Exception: with lock: results[i] = ExceptionWrapper( where="in replica {} on device {}".format(i, device)) if len(modules) > 1: threads = [ threading.Thread(target=_worker, args=(i, module, input, target, kwargs, device)) for i, ( module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices)) ] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, ExceptionWrapper): output.reraise() outputs.append(output) return outputs
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): r"""Applies each `module` in :attr:`modules` in parallel on arguments contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) on each of :attr:`devices`. Args: modules (Module): modules to be parallelized inputs (tensor): inputs to the modules devices (list of int or torch.device): CUDA devices :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and :attr:`devices` (if given) should all have same length. Moreover, each element of :attr:`inputs` is passed to its module as a single argument; it is not unpacked into positional arguments. """ assert len(modules) == len(inputs) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({}, ) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) devices = list(map(lambda x: _get_device_index(x, True), devices)) lock = threading.Lock() results = {} grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, kwargs, device=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): # unlike the stock implementation, `input` is passed to the # module as a single object rather than unpacked into a tuple output = module(input) with lock: results[i] = output except Exception: with lock: results[i] = ExceptionWrapper( where="in replica {} on device {}".format(i, device)) if len(modules) > 1: threads = [ threading.Thread(target=_worker, args=(i, module, input, kwargs, device)) for i, ( module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices)) ] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, ExceptionWrapper): output.reraise() outputs.append(output) return outputs
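A minimal sketch of the scatter/replicate/parallel_apply/gather pipeline using the parallel_apply variant defined above; it assumes two visible CUDA devices and uses the stock torch.nn.parallel scatter, replicate, and gather helpers.

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, gather

device_ids = [0, 1]
module = nn.Linear(16, 4).cuda(device_ids[0])

# scatter splits the batch along dim 0, producing one chunk per device.
inputs = scatter(torch.randn(8, 16, device="cuda:0"), device_ids)
replicas = replicate(module, device_ids[:len(inputs)])

# This parallel_apply variant calls module(input) directly, so each
# scattered chunk is passed as a single tensor rather than unpacked.
outputs = parallel_apply(replicas, inputs, devices=device_ids[:len(inputs)])
result = gather(outputs, device_ids[0])  # shape [8, 4] on cuda:0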
def replicate(network, devices, detach=False): from ._functions import Broadcast if not _replicatable_module(network): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = Broadcast.apply(devices, *params) if len(params) > 0: param_copies = [ param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params)) ] buffers = list(network.buffers()) buffer_indices = {buf: idx for idx, buf in enumerate(buffers)} buffer_copies = comm.broadcast_coalesced(buffers, devices) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"} for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): if _is_script_module(module): # we have to initialize ScriptModule properly so that # it works with pybind11 replica = _init_script_module() keys = set(module.__dict__.keys()) - scriptmodule_skip_attr for key in keys: replica.__dict__[key] = module.__dict__[key] else: replica = module.__new__(type(module)) replica.__dict__ = module.__dict__.copy() replica._parameters = replica._parameters.copy() replica._buffers = replica._buffers.copy() replica._modules = replica._modules.copy() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = module_copies[j][module_idx] for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = param_copies[j][param_idx].detach() \ if detach else param_copies[j][param_idx] for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: buffer_idx = buffer_indices[buf] for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = buffer_copies[j][buffer_idx] for j in range(num_replicas): _copy_scriptmodule_methods(modules, module_copies[j], module_indices) return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices, detach=False): if not _replicatable_module(network): raise RuntimeError( "Cannot replicate network where python modules are " "children of ScriptModule" ) devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = _broadcast_coalesced_reshape(params, devices, detach) buffers = list(network.buffers()) buffers_rg = [] buffers_not_rg = [] for buf in buffers: if buf.requires_grad and not detach: buffers_rg.append(buf) else: buffers_not_rg.append(buf) buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} buffer_copies_rg = _broadcast_coalesced_reshape( buffers_rg, devices, detach=detach ) buffer_copies_not_rg = _broadcast_coalesced_reshape( buffers_not_rg, devices, detach=True ) modules = list(network.modules()) module_copies = [[] for _ in devices] module_indices = {} scriptmodule_skip_attr = { "_parameters", "_buffers", "_modules", "forward", "_c", } for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): if _is_script_module(module): # we have to initialize ScriptModule properly so that # it works with pybind11 replica = _init_script_module() attribute_names = set( entry[0] for entry in module._c._get_attributes() ) keys = ( set(module.__dict__.keys()) - scriptmodule_skip_attr - attribute_names ) for key in keys: if not _is_script_method(module.__dict__[key]): replica.__dict__[key] = module.__dict__[key] for name, the_type, value in module._c._get_attributes(): if name in module._buffers.keys(): continue replica._c._register_attribute(name, the_type, value) else: replica = module.__new__(type(module)) replica.__dict__ = module.__dict__.copy() replica._parameters = replica._parameters.copy() replica._buffers = replica._buffers.copy() replica._modules = replica._modules.copy() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = module_copies[j][module_idx] for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices.get(param, None) if param_idx is None: continue for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = param_copies[j][param_idx] for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: if buf.requires_grad and not detach: buffer_copies = buffer_copies_rg buffer_idx = buffer_indices_rg.get(buf, None) else: buffer_copies = buffer_copies_not_rg buffer_idx = buffer_indices_not_rg.get(buf, None) if buffer_idx is None: continue for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = buffer_copies[j][buffer_idx] for j in range(num_replicas): _copy_scriptmodule_methods(modules, module_copies[j], module_indices) replicas = [module_copies[j][0] for j in range(num_replicas)] for model_replica in replicas: for _, submodule in model_replica.named_modules(): if hasattr(submodule, "on_replicate") and callable(submodule.on_replicate): submodule.on_replicate() return replicas
def __init__(self, module, device_ids=None, broadcast_buffers=True, compression=Compression.none ): super(DistributedDataParallel, self).__init__() assert device_ids and len(device_ids) == 1, ( "DistributedDataParallel device_ids must contain exactly one entry," " but got {}.").format(device_ids) self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) self.module = module self.broadcast_buffers = broadcast_buffers self.require_forward_param_sync = broadcast_buffers self._handles = {} self._grad_accs = [] self._requires_update = set() self._num_grads = 1 self.modules_buffers = [list(self.module.buffers())] self._compression = compression self._enable_async = False self._require_backward_grad_sync = True named_parameters = self.module.named_parameters() named_parameters = list(named_parameters) if len(named_parameters) > 0: if isinstance(named_parameters[0][1], torch.Tensor): if any([not isinstance(p, torch.Tensor) for name, p in named_parameters]): raise ValueError('named_parameters should consistently be a sequence of ' 'tuples (name, torch.Tensor)') self._is_tensor_instance = True # there is an issue when using torch.Tensor as key, so use its hash instead # https://github.com/pytorch/pytorch/issues/7733 self._parameter_names = {v.__hash__(): k for k, v in sorted(named_parameters)} self._tensor_list = [tensor for name, tensor in named_parameters] else: self._is_tensor_instance = False self._parameter_names = {v: k for k, v in sorted(named_parameters)} else: self._is_tensor_instance = False self._parameter_names = {v: 'push_pull.noname.%s' % i for param_group in self.param_groups for i, v in enumerate(param_group['params'])} if size() > 1: self._register_hooks() named_params = self.module.named_parameters() self._num_grads = sum(p.requires_grad for _, p in named_params) byteps_torch_set_num_grads(self._num_grads) # declare tensors for name in sorted(self._parameter_names.values()): declare("Gradient."+name) # We use two loops for load-balancing for name in sorted(self._parameter_names.values()): declare("Parameter."+name) # broadcast model state module_states = list(self.module.state_dict().values()) if len(module_states) > 0: bps.torch.broadcast_parameters(self.module.state_dict(), root_rank=0)
def replicate(network, devices, detach=False): from ._functions import Broadcast devices = list(map(lambda x: _get_device_index(x, True), devices)) num_replicas = len(devices) params = list(network.parameters()) param_indices = {param: idx for idx, param in enumerate(params)} param_copies = Broadcast.apply(devices, *params) if len(params) > 0: param_copies = [ param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params)) ] buffers = list(network.buffers()) buffer_indices = {buf: idx for idx, buf in enumerate(buffers)} buffer_copies = comm.broadcast_coalesced(buffers, devices) modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} for i, module in enumerate(modules): module_indices[module] = i for j in range(num_replicas): replica = module.__new__(type(module)) replica.__dict__ = module.__dict__.copy() replica._parameters = replica._parameters.copy() replica._buffers = replica._buffers.copy() replica._modules = replica._modules.copy() module_copies[j].append(replica) for i, module in enumerate(modules): for key, child in module._modules.items(): if child is None: for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = None else: module_idx = module_indices[child] for j in range(num_replicas): replica = module_copies[j][i] replica._modules[key] = module_copies[j][module_idx] for key, param in module._parameters.items(): if param is None: for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = None else: param_idx = param_indices[param] for j in range(num_replicas): replica = module_copies[j][i] replica._parameters[key] = param_copies[j][param_idx].detach() \ if detach else param_copies[j][param_idx] for key, buf in module._buffers.items(): if buf is None: for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = None else: buffer_idx = buffer_indices[buf] for j in range(num_replicas): replica = module_copies[j][i] replica._buffers[key] = buffer_copies[j][buffer_idx] return [module_copies[j][0] for j in range(num_replicas)]
def parallel_apply_predict(modules, inputs, kwargs_tup=None, devices=None): """Applies each `module` predict method in `modules` in parallel on arguments contained in `inputs` (positional) and `kwargs_tup` (keyword) on each of `devices`. Args: modules: modules to be parallelized. inputs: inputs to the modules. devices: CUDA devices. """ assert len(modules) == len(inputs) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({},) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) devices = list(map(lambda x: _get_device_index(x, True), devices)) lock = threading.Lock() results = {} grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, kwargs, device=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input,) output = module.predict(*input, **kwargs) with lock: results[i] = output except Exception: with lock: results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) if len(modules) > 1: threads = [ threading.Thread(target=_worker, args=(i, module, input, kwargs, device)) for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices)) ] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, ExceptionWrapper): output.reraise() outputs.append(output) return outputs
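A hedged sketch for parallel_apply_predict above: a hypothetical Classifier module exposing the predict() method that the helper dispatches to, replicated across two assumed CUDA devices with the stock torch.nn.parallel scatter and replicate helpers.

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        return self.fc(x)

    def predict(self, x):
        # parallel_apply_predict calls module.predict(*input, **kwargs)
        return self.forward(x).argmax(dim=-1)

device_ids = [0, 1]
model = Classifier().cuda(device_ids[0])
inputs = scatter(torch.randn(8, 16, device="cuda:0"), device_ids)
replicas = replicate(model, device_ids[:len(inputs)])
preds = parallel_apply_predict(replicas, inputs,
                               devices=device_ids[:len(inputs)])
# preds is a list with one prediction tensor per device.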
def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25): super(DistributedDataParallel, self).__init__() # Use all devices by default if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0] if process_group is None: self.process_group = dist.get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = 250 * MB # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) if len(device_ids) > 1: # TODO: we don't need to replicate params in here. they're always going to # be broadcasted using larger blocks in broadcast_coalesced, so it might be # better to not pollute the caches with these small blocks self._module_copies = replicate(self.module, self.device_ids, detach=True) self._module_copies[0] = self.module for module_copy in self._module_copies[1:]: for param, copy_param in zip(self.module.parameters(), module_copy.parameters()): copy_param.requires_grad = param.requires_grad else: self._module_copies = [self.module] self.modules_params_data = [[] for _ in range(len(self.device_ids))] self.modules_buffers_data = [[] for _ in range(len(self.device_ids))] for dev_idx, module in enumerate(self._module_copies): self.modules_params_data[dev_idx] = [ p.data for p in module.parameters() ] self.modules_buffers_data[dev_idx] = [ b.data for b in module.buffers() ] bucket_bytes_cap = bucket_cap_mb * MB # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems param_buckets = [] # Split the parameters into buckets and by types as well param_buckets = [ list(_take_tensors(m.parameters(), bucket_bytes_cap)) for m in self._module_copies ] self.bucket_sizes = [] self.bucket_map = {} # We transpose param_buckets, so the loop is over buckets. # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)): self.bucket_sizes.append(0) # Now, we transpose again, so we iterate over bucket_elems, but getting tuples # of params from each device. for param_tuple in zip(*param_buckets_tuple): if not param_tuple[0].requires_grad: continue for p in param_tuple: self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx]) self.bucket_sizes[bucket_idx] += 1 self.buckets = [[[None for _ in range(self.bucket_sizes[i])] for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))] # The number of params ready in each bucket self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))] # coalesced bucket for only device 0 self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))] # We will always reduce the buckets in reverse order, # that is, always reduce following the order of: n - 1, n - 2, ..., 0 self.next_bucket = len(self.bucket_sizes) - 1 self.ready_buckets_not_reduced = set() self.reduction_works = [None for _ in range(len(self.bucket_sizes))] self.devs_ready = [0 for _ in range(len(self.bucket_sizes))] self._register_grad_hooks()
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no-cover r"""Applies each `module` in :attr:`modules` in parallel on arguments contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) on each of :attr:`devices`. Args: modules (Module): modules to be parallelized inputs (tensor): inputs to the modules devices (list of int or torch.device): CUDA devices :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and :attr:`devices` (if given) should all have same length. Moreover, each element of :attr:`inputs` can either be a single object as the only argument to a module, or a collection of positional arguments. """ assert len(modules) == len(inputs) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({}, ) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) devices = list(map(lambda x: _get_device_index(x, True), devices)) lock = threading.Lock() results = {} grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, kwargs, device=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input, ) # --------------- # CHANGE if module.training: output = module.training_step(*input, **kwargs) elif module.testing: output = module.test_step(*input, **kwargs) else: output = module.validation_step(*input, **kwargs) if module.use_dp or module.use_ddp2: auto_squeeze_dim_zeros(output) # --------------- with lock: results[i] = output except Exception as ex: with lock: results[i] = ex # TODO: fix hack (maybe not a hack) # make sure each module knows what training state it's in... # fixes weird bug where copies are out of sync root_m = modules[0] for m in modules[1:]: m.training = root_m.training m.testing = root_m.testing if len(modules) > 1: threads = [ threading.Thread(target=_worker, args=(i, module, input, kwargs, device)) for i, ( module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices)) ] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, Exception): raise output outputs.append(output) return outputs
def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False): super(DistributedDataParallel, self).__init__() self.is_multi_device_module = len( {p.device for p in module.parameters()}) > 1 self.is_cuda = all( [p.device.type == 'cuda' for p in module.parameters()]) if not self.is_cuda or self.is_multi_device_module: assert not device_ids and not output_device, ( "DistributedDataParallel device_ids and output_device arguments " "only work with single-device CUDA modules, but got " "device_ids {}, output_device {}, and module parameters {}." ).format(device_ids, output_device, {p.device for p in module.parameters()}) self.device_ids = None self.output_device = None else: # Use all devices by default for single-device CUDA modules if device_ids is None: device_ids = list(range(torch.cuda.device_count())) self.device_ids = list( map(lambda x: _get_device_index(x, True), device_ids)) # if output_device is None: # output_device = device_ids[0] self.output_device = _get_device_index(output_device, True) if self.is_multi_device_module: assert self.is_cuda, ( "DistributedDataParallel with multi-device module only works " "with CUDA devices, but module parameters locate in {}." ).format({p.device for p in module.parameters()}) if process_group is None: self.process_group = _get_default_group() else: self.process_group = process_group self.dim = dim self.module = module self.broadcast_buffers = broadcast_buffers self.find_unused_parameters = find_unused_parameters self.require_backward_grad_sync = True self.require_forward_param_sync = True if check_reduction: # This argument is no longer used since the reducer # will ensure reduction completes even if some parameters # do not receive gradients. pass MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = int(250 * MB) # reduction bucket size self.bucket_bytes_cap = int(bucket_cap_mb * MB) # Sync params and buffers module_states = list(self.module.state_dict().values()) if len(module_states) > 0: self._distributed_broadcast_coalesced(module_states, self.broadcast_bucket_size) self._ddp_init_helper()