Example #1
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(torch.nn.DataParallel, self).__init__()
        #super().__init__(module, device_ids, output_device, dim)

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(
            self.device_ids[0]))
        #print("module:", module, "dim:", dim, "device_ids:", self.device_ids, \
        #      "output_device:", self.output_device, "src_device_obj:", self.src_device_obj)

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            #print("len(self.device_ids):", len(self.device_ids))
            self.module.cuda(device_ids[0])
Example #2
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))

        self.n_device = len(device_ids)
        self.seen = 0

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])
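A minimal usage sketch for a constructor with the signature shown in Examples #1 and #2, using the stock torch.nn.DataParallel. The toy module, device ids, and batch size are made up for illustration, and at least two visible CUDA devices are assumed:

import torch
import torch.nn as nn

# Hypothetical toy module; any nn.Module works the same way.
net = nn.Linear(128, 10)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # Parameters must live on device_ids[0] before the wrapper is used.
    net = net.cuda(0)
    # Replication, scatter, and gather all happen inside forward().
    net = nn.DataParallel(net, device_ids=[0, 1], output_device=0, dim=0)

x = torch.randn(64, 128)
if torch.cuda.is_available():
    x = x.cuda(0)
out = net(x)  # batch is split along dim 0; result lands on output_device
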
Example #3
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallelImbalance, self).__init__(module, device_ids,
                                                    output_device, dim)

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        if not all(t.is_cuda and t.device.index == device_ids[0]
                   for t in chain(module.parameters(), module.buffers())):
            raise RuntimeError("module must have its parameters and buffers "
                               "on device %d (device_ids[0])" % device_ids[0])

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])
Example #4
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 check_reduction=False,
                 randk=1,
                 seed=2147483647):

        super(RandomKSparsifiedDDP, self).__init__()

        torch.manual_seed(seed)
        # Use all devices by default
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        if len(device_ids) > 1:
            raise RuntimeError(
                "This module only supports Multi-Process Single-GPU mode.")

        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers
        self.check_reduction = check_reduction
        self.randk = randk
        self.masks = {}

        self.global_step = 0

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 250 * MB

        # reduction bucket size
        self.bucket_bytes_cap = bucket_cap_mb * MB

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        self._ddp_init_helper()
Example #5
def data_parallel(
    module,
    inputs,
    device_ids=None,
    output_device=None,
    dim=0,
    module_kwargs=None,
    non_scatter_kwargs=None,
):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output
            Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for tensor in chain(module.parameters(), module.buffers()):
        if tensor.device != src_device_obj:
            raise RuntimeError(
                "module must have its parameters and buffers "
                "on device {} (device_ids[0]) but found one of "
                "them on device: {}".format(src_device_obj, tensor.device)
            )

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[: len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
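A minimal calling sketch for the functional form documented above, assuming the stock torch.nn.parallel.data_parallel (which this example mirrors) and at least two visible CUDA devices; the model and tensor shapes are placeholders:

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

model = nn.Linear(128, 10).cuda(0)          # parameters on device_ids[0]
x = torch.randn(64, 128, device="cuda:0")

# One-shot scatter -> replicate -> parallel_apply -> gather; no wrapper
# module is kept around between calls.
out = data_parallel(model, x, device_ids=[0, 1], output_device=0)
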
Example #6
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()

        # Use all devices by default
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers
        if check_reduction:
            # This argument is no longer used since the reducer
            # will ensure reduction completes even if some parameters
            # do not receive gradients.
            pass

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * MB)

        # reduction bucket size
        self.bucket_bytes_cap = int(bucket_cap_mb * MB)

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        self._ddp_init_helper()
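For context, a minimal multi-process, single-GPU-per-process setup that drives a constructor with this signature, written against the stock torch.nn.parallel.DistributedDataParallel. The nccl backend and env:// rendezvous are typical choices, not requirements, and the local-rank derivation is a common heuristic:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# One process per GPU; RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are read
# from the environment (e.g. set by torchrun).
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = nn.Linear(128, 10).cuda(local_rank)
ddp_model = DistributedDataParallel(model,
                                    device_ids=[local_rank],
                                    output_device=local_rank,
                                    broadcast_buffers=True,
                                    bucket_cap_mb=25)
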
Example #7
    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        check_reduction=False,
        find_unused_parameters=False,
    ):

        super(DistributedDataParallelPythonBuckets, self).__init__()

        # Use all devices by default
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers
        self.check_reduction = check_reduction

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 250 * MB

        # reduction bucket size
        self.bucket_bytes_cap = bucket_cap_mb * MB

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        self._ddp_init_helper()
Example #8
    def __init__(self,
                 module,
                 device_id=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 check_reduction=False,
                 sparse_ratio=0.1,
                 sparse_threshold=1024,
                 mem_decay=1.0):

        super(DistributedDataParallel, self).__init__()
        self.module = module

        if device_id is None:
            raise RuntimeError("device_id cannot be None")

        if output_device is None:
            output_device = device_id

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_id = _get_device_index(device_id, True)
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers
        self.check_reduction = check_reduction
        self.sparse_ratio = sparse_ratio
        self.sparse_threshold = sparse_threshold
        self.mem_decay = mem_decay

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 250 * MB

        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        self._ddp_init_helper()
Example #9
def _to_device_index(devices):
    if not devices:
        raise RuntimeError("Cannot replicate using an empty device list.")

    if isinstance(devices, list) and isinstance(devices[0], list):
        device_ids = []
        seen = set()
        for i, replica_devs in enumerate(devices):
            assert len(replica_devs) == len(devices[0]), (
                "Cannot replicate to unidentical number of devices, but got "
                "device list {} and {} for replica {} and {}.").format(
                    devices[0], devices[i], 0, i)

            assert len(seen.intersection(replica_devs)) == 0, (
                "Devices {} are shared by multiple replicas.").format(
                    seen.intersection(replica_devs))
            seen.update(replica_devs)

            device_ids.append(_to_device_index(replica_devs))
        return device_ids
    else:
        assert len(devices) == len(
            set(devices)), ("Duplicated device ids {}.").format(devices)

        return list(map(lambda x: _get_device_index(x, True), devices))
Example #10
def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    #print("_check_balance device_ids:", device_ids)
    dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

    #print("_check_balance dev_props:", dev_props)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(
                imbalance_warn.format(device_ids[min_pos],
                                      device_ids[max_pos]))
            return True
        return False

    #print("warn_imbalance total_memory:", warn_imbalance(lambda props: props.total_memory), \
    #      "warn_imbalance multi_processor_count:", warn_imbalance(lambda props: props.multi_processor_count), \
    #      "multi_processor_count:", props.multi_processor_count)
    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return
Example #11
    def replicate_module(self, module: torch.nn.Module,
                         devices: List[int]) -> List[torch.nn.Module]:
        assert self.n_mask_samples % len(devices) == 0
        copies = replicate(module, devices)

        def walk(module: torch.nn.Module, copy: torch.nn.Module):
            module_map = {id(module): copy}
            for name, ref in module._modules.items():
                module_map.update(walk(ref, getattr(copy, name)))

            return module_map

        devices = [_get_device_index(d) for d in devices]

        # Copy the custom parameters
        all_params = [p.get() for p in self.pointer_values]

        if (not self.masking_enabled) or (not self.training):
            scattered = _broadcast_coalesced_reshape(all_params, devices)
        else:
            # This case is more complicated, because there might be non-masked
            # parameters that have to be handled in the usual way.
            masked_indices = [
                i for i, n in enumerate(self.param_names) if self.is_masked(n)
            ]
            simple_indices = [
                i for i, n in enumerate(self.param_names)
                if not self.is_masked(n)
            ]

            masked_params = scatter([all_params[i] for i in masked_indices],
                                    devices)
            simple_params = _broadcast_coalesced_reshape(
                [all_params[i] for i in simple_indices], devices)

            scattered = [[None] * len(all_params) for _ in devices]
            for d in range(len(devices)):
                for mi, mp in zip(masked_indices, masked_params[d]):
                    scattered[d][mi] = mp

                for si, sp in zip(simple_indices, simple_params[d]):
                    scattered[d][si] = sp

        for i, c in enumerate(copies):
            device_map = walk(module, c)
            for j, p in enumerate(self.pointer_values):
                setattr(device_map[id(p.parent)], p.name, scattered[i][j])

            self.update_rnn_params(c)

        return copies
Example #12
 def forward(ctx, target_device, dim, *inputs):
     assert all(map(lambda i: i.is_cuda, inputs))
     target_device = _get_device_index(target_device, True)
     ctx.target_device = target_device
     ctx.dim = dim
     ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
     if all(t.dim() == 0 for t in inputs) and dim == 0:
         inputs = tuple(t.view(1) for t in inputs)
         warnings.warn('Was asked to gather along dimension 0, but all '
                       'input tensors were scalars; will instead unsqueeze '
                       'and return a vector.')
         ctx.unsqueezed_scalar = True
     else:
         ctx.unsqueezed_scalar = False
     ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
     return comm.gather(inputs, ctx.dim, ctx.target_device)
Example #13
 def forward(ctx, target_gpus, chunk_sizes, dim, input):
     target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
     ctx.dim = dim
     ctx.input_device = input.get_device() if input.is_cuda else -1
     streams = None
     if ctx.input_device == -1:
         # Perform CPU to GPU copies in a background stream
         streams = [_get_stream(device) for device in target_gpus]
     outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
     # Synchronize with the copy stream
     if streams is not None:
         for i, output in enumerate(outputs):
             with torch.cuda.device(target_gpus[i]):
                 main_stream = torch.cuda.current_stream()
                 main_stream.wait_stream(streams[i])
                 output.record_stream(main_stream)
     return outputs
Example #14
 def forward(ctx, target_gpus, *inputs):
     if not all(input.is_cuda for input in inputs):
         raise TypeError('Broadcast function not implemented for CPU tensors')
     target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
     ctx.target_gpus = target_gpus
     if len(inputs) == 0:
         return tuple()
     ctx.num_inputs = len(inputs)
     ctx.input_device = inputs[0].get_device()
     outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
     non_differentiables = []
     for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
         if not input_requires_grad:
             for output in outputs:
                 non_differentiables.append(output[idx])
     ctx.mark_non_differentiable(*non_differentiables)
     return tuple([t for tensors in outputs for t in tensors])
Example #15
 def __init__(self,
              module,
              device_ids,
              dim=0,
              broadcast_buffers=False,
              find_unused_parameters=False,
              **kwargs):
     super().__init__()
     assert len(device_ids) == 1, (
         'Currently, DistributedDataParallelWrapper only supports a '
         'single CUDA device for each process. '
         f'The length of device_ids must be 1, but got {len(device_ids)}.')
     self.module = module
     self.dim = dim
     self.to_ddp(device_ids=device_ids,
                 dim=dim,
                 broadcast_buffers=broadcast_buffers,
                 find_unused_parameters=find_unused_parameters,
                 **kwargs)
     self.output_device = _get_device_index(device_ids[0], True)
Example #16
    def __init__(self,
                 module,
                 split_size,
                 device_ids=None,
                 output_device=None):
        """
        This method instantiates a ModelParallel module from a module instance passed by the user.
        The model must take a single input (a forward(x)-style signature for the forward method);
        otherwise an error is raised.

        An example is here:

        .. code-block:: python

            from eisen.utils import ModelParallel
            from eisen.models.segmentation import UNet


            model = ModelParallel(
                module=UNet(input_channels=1, output_channels=1),
                split_size=2,
                device_ids=[0, 1, 2, 3],
                output_device=0
            )


        :param module: an instance of the model that should be parallelized
        :type module: torch.nn.Module
        :param split_size: split size for pipelined execution
        :type split_size: int
        :param device_ids: list of int or torch devices indicating GPUs to use
        :type device_ids: list
        :param output_device: int or torch device indicating output devices
        :type output_device: int or torch device
        """
        super(ModelParallel, self).__init__()

        module_argument_list = inspect.getfullargspec(module.forward)[0]

        if len(module_argument_list) > 2:
            raise NotImplementedError(
                'Support for modules with more than one input is not yet implemented.'
            )

        self.first_run = True
        self.split_size = split_size

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            self.first_run = False
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)

        if len(self.device_ids) == 1:
            self.first_run = False
            self.module.cuda(device_ids[0])
Example #17
def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {
        buf: idx
        for idx, buf in enumerate(buffers_not_rg)
    }

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg,
                                                    devices,
                                                    detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg,
                                                        devices,
                                                        detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {
        "_parameters", "_buffers", "_modules", "forward", "_c"
    }

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                def init_fn(script_module):
                    # Don't do anything here, we'll initialize the ScriptModule below
                    return

                replica = torch.jit.RecursiveScriptModule._construct(
                    module._c._replicate_for_data_parallel(), init_fn)
            else:
                replica = module._replicate_for_data_parallel()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param = param_copies[j][param_idx]
                    setattr(
                        replica, key,
                        Parameter(param, requires_grad=param.requires_grad))
                    # TODO: We need to manually set _parameters with a bare
                    # non-parameter Tensor, otherwise gradients don't
                    # accumulate in the original parameters when you call
                    # backwards() on the DataParallel module.
                    replica._parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [module_copies[j][0] for j in range(num_replicas)]
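The replicate() above is one stage of the data-parallel forward path. A hand-wired sketch of the full scatter/replicate/parallel_apply/gather sequence, using the public torch.nn.parallel helpers rather than this internal variant, with a toy module and two GPUs assumed:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [0, 1]
model = nn.Linear(128, 10).cuda(device_ids[0])
x = torch.randn(64, 128, device="cuda:0")

inputs = scatter(x, device_ids, dim=0)      # split the batch across GPUs
replicas = replicate(model, device_ids)     # broadcast params and buffers
outputs = parallel_apply(replicas, inputs)  # one forward pass per device
result = gather(outputs, 0, dim=0)          # collect results on cuda:0
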
Example #18
def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the 
            # replicated model parameters. It used to do so through 
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [module_copies[j][0] for j in range(num_replicas)]
Example #19
def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {
        buf: idx
        for idx, buf in enumerate(buffers_not_rg)
    }

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg,
                                                    devices,
                                                    detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg,
                                                        devices,
                                                        detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {
        "_parameters", "_buffers", "_modules", "forward", "_c"
    }

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                def init_fn(script_module):
                    # Don't do anything here, we'll initialize the ScriptModule below
                    return

                replica = torch.jit.RecursiveScriptModule._construct(
                    module._c._replicate_for_data_parallel(), init_fn)
            else:
                replica = module._replicate_for_data_parallel()

            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves, so remove them from _parameters
                    # and setattr them as non-parameter attributes
                    # scripted modules don't allow deleting parameters, but also don't complain
                    # on assigning non-Parameter type
                    if (not _is_script_module(replica)):
                        del replica._parameters[key]
                    setattr(replica, key, param)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [module_copies[j][0] for j in range(num_replicas)]
Example #20
def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
Example #21
def criterion_parallel_apply(modules,
                             inputs,
                             targets,
                             kwargs_tup=None,
                             devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                if not isinstance(input, (list, tuple)):
                    input = (input, )
                if not isinstance(target, (list, tuple)):
                    target = (target, )
                output = module(*input, *target, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, input, target, kwargs, device))
            for i, (module, input, target, kwargs, device) in enumerate(
                zip(modules, inputs, targets, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
Example #22
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.
    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices
    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                #if not isinstance(input, (list, tuple)):
                #    input = (input,)
                output = module(input)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, input, kwargs, device))
            for i, (
                module, input, kwargs,
                device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
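A small driving sketch that exercises the stock torch.nn.parallel.parallel_apply described in the docstring above (note that the variant in this example skips the input wrapping and ignores kwargs); the replicas are placed by hand here rather than via replicate(), and the shapes and devices are illustrative:

import copy
import torch
import torch.nn as nn
from torch.nn.parallel import parallel_apply

base = nn.Linear(128, 10)
replicas = [copy.deepcopy(base).cuda(d) for d in (0, 1)]
inputs = [torch.randn(32, 128, device="cuda:%d" % d) for d in (0, 1)]

# One entry per replica; kwargs_tup and devices must match in length.
outputs = parallel_apply(replicas, inputs, kwargs_tup=({}, {}), devices=[0, 1])
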
Example #23
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [
            param_copies[i:i + len(params)]
            for i in range(0, len(param_copies), len(params))
        ]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
Example #24
    def replicate(network, devices, detach=False):
        if not _replicatable_module(network):
            raise RuntimeError(
                "Cannot replicate network where python modules are "
                "childrens of ScriptModule"
            )

        devices = list(map(lambda x: _get_device_index(x, True), devices))
        num_replicas = len(devices)

        params = list(network.parameters())
        param_indices = {param: idx for idx, param in enumerate(params)}
        param_copies = _broadcast_coalesced_reshape(params, devices, detach)

        buffers = list(network.buffers())
        buffers_rg = []
        buffers_not_rg = []
        for buf in buffers:
            if buf.requires_grad and not detach:
                buffers_rg.append(buf)
            else:
                buffers_not_rg.append(buf)

        buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
        buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

        buffer_copies_rg = _broadcast_coalesced_reshape(
            buffers_rg, devices, detach=detach
        )
        buffer_copies_not_rg = _broadcast_coalesced_reshape(
            buffers_not_rg, devices, detach=True
        )

        modules = list(network.modules())
        module_copies = [[] for _ in devices]
        module_indices = {}
        scriptmodule_skip_attr = {
            "_parameters",
            "_buffers",
            "_modules",
            "forward",
            "_c",
        }

        for i, module in enumerate(modules):
            module_indices[module] = i
            for j in range(num_replicas):
                if _is_script_module(module):
                    # we have to initialize ScriptModule properly so that
                    # it works with pybind11
                    replica = _init_script_module()

                    attribute_names = set(
                        entry[0] for entry in module._c._get_attributes()
                    )

                    keys = (
                        set(module.__dict__.keys())
                        - scriptmodule_skip_attr
                        - attribute_names
                    )
                    for key in keys:
                        if not _is_script_method(module.__dict__[key]):
                            replica.__dict__[key] = module.__dict__[key]
                    for name, the_type, value in module._c._get_attributes():
                        if name in module._buffers.keys():
                            continue
                        replica._c._register_attribute(name, the_type, value)
                else:
                    replica = module.__new__(type(module))
                    replica.__dict__ = module.__dict__.copy()
                    replica._parameters = replica._parameters.copy()
                    replica._buffers = replica._buffers.copy()
                    replica._modules = replica._modules.copy()

                module_copies[j].append(replica)

        for i, module in enumerate(modules):
            for key, child in module._modules.items():
                if child is None:
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._modules[key] = None
                else:
                    module_idx = module_indices[child]
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._modules[key] = module_copies[j][module_idx]
            for key, param in module._parameters.items():
                if param is None:
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._parameters[key] = None
                else:
                    param_idx = param_indices.get(param, None)
                    if param_idx is None:
                        continue
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._parameters[key] = param_copies[j][param_idx]
            for key, buf in module._buffers.items():
                if buf is None:
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._buffers[key] = None
                else:
                    if buf.requires_grad and not detach:
                        buffer_copies = buffer_copies_rg
                        buffer_idx = buffer_indices_rg.get(buf, None)
                    else:
                        buffer_copies = buffer_copies_not_rg
                        buffer_idx = buffer_indices_not_rg.get(buf, None)
                    if buffer_idx is None:
                        continue
                    for j in range(num_replicas):
                        replica = module_copies[j][i]
                        replica._buffers[key] = buffer_copies[j][buffer_idx]

        for j in range(num_replicas):
            _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

        replicas = [module_copies[j][0] for j in range(num_replicas)]
        for model_replica in replicas:
            for _, submodule in model_replica.named_modules():
                if hasattr(submodule, "on_replicate") and callable(
                    submodule.on_replicate
                ):
                    submodule.on_replicate()
        return replicas
Example #25
    def __init__(self, module, device_ids=None,
            broadcast_buffers=True,
            compression=Compression.none
            ):
        super(DistributedDataParallel, self).__init__()

        assert device_ids and len(device_ids) == 1, (
                "DistributedDataParallel device_ids contain exactlyone entry,"
                " but got {}.").format(device_ids)
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.module = module
        self.broadcast_buffers = broadcast_buffers
        self.require_forward_param_sync = broadcast_buffers
        self._handles = {}
        self._grad_accs = []
        self._requires_update = set()
        self._num_grads = 1

        self.modules_buffers = [list(self.module.buffers())]
        self._compression = compression
        self._enable_async = False
        self._require_backward_grad_sync = True
        named_parameters = self.module.named_parameters()
        named_parameters = list(named_parameters)
        if len(named_parameters) > 0:
            if isinstance(named_parameters[0][1], torch.Tensor):
                if any([not isinstance(p, torch.Tensor) for name, p in named_parameters]):
                    raise ValueError('named_parameters should consistently be a sequence of '
                                     'tuples (name, torch.Tensor)')
                self._is_tensor_instance = True
                # there is an issue when using torch.Tensor as key, so use its hash instead
                # https://github.com/pytorch/pytorch/issues/7733
                self._parameter_names = {v.__hash__(): k for k, v
                                         in sorted(named_parameters)}
                self._tensor_list = [tensor for name, tensor in named_parameters]
            else:
                self._is_tensor_instance = False
                self._parameter_names = {v: k for k, v
                                         in sorted(named_parameters)}
        else:
            self._is_tensor_instance = False
            self._parameter_names = {v: 'push_pull.noname.%s' % i
                                     for param_group in self.param_groups
                                     for i, v in enumerate(param_group['params'])}
        if size() > 1:
            self._register_hooks()
            named_params = self.module.named_parameters()
            self._num_grads = sum(p.requires_grad for _, p in named_params)
            byteps_torch_set_num_grads(self._num_grads)

        # declare tensors
        for name in sorted(self._parameter_names.values()):
            declare("Gradient."+name)
        # We use two loops for load-balancing
        for name in sorted(self._parameter_names.values()):
            declare("Parameter."+name)

        # broadcast model state
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            bps.torch.broadcast_parameters(self.module.state_dict(), root_rank=0)
Example #26
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [
            param_copies[i:i + len(params)]
            for i in range(0, len(param_copies), len(params))
        ]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
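An illustrative call to the replicate() above, assuming it lives in a package where the relative import of Broadcast resolves (the function mirrors torch.nn.parallel.replicate); guarded so it only runs when at least one GPU is visible.

import torch

if torch.cuda.is_available():
    devices = list(range(torch.cuda.device_count()))
    net = torch.nn.Linear(8, 2).cuda(devices[0])   # keep the source module on devices[0]
    replicas = replicate(net, devices)             # one copy of the module per device, in device order
    assert len(replicas) == len(devices)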
Example #27
def parallel_apply_predict(modules, inputs, kwargs_tup=None, devices=None):
    """Applies each `module` predict method in `modules` in parallel on arguments
    contained in `inputs` (positional) and `kwargs_tup` (keyword)
    on each of `devices`.

    Args:
        modules: modules to be parallelized.
        inputs: inputs to the modules.
        devices: CUDA devices.

    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module.predict(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker, args=(i, module, input, kwargs, device))
            for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
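A hypothetical end-to-end sketch for parallel_apply_predict(): a toy module exposing a predict method, replicated onto two GPUs with the replicate() from the previous example. TinyPredictor and the batch shapes are made up for illustration, and the block is guarded to run only when two GPUs are present.

import torch

class TinyPredictor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def predict(self, x):
        # the hook that parallel_apply_predict() calls on every replica
        return self.fc(x).argmax(dim=-1)

if torch.cuda.device_count() >= 2:
    devices = [0, 1]
    base = TinyPredictor().cuda(devices[0])
    replicas = replicate(base, devices)
    inputs = [torch.randn(3, 4, device="cuda:{}".format(d)) for d in devices]
    preds = parallel_apply_predict(replicas, inputs, devices=devices)   # one output tensor per device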
Example #28
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25):

        super(DistributedDataParallel, self).__init__()

        # Use all devices by default
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = dist.get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 250 * MB

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        if len(device_ids) > 1:
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcasted using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module,
                                            self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module

            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(),
                                             module_copy.parameters()):
                    copy_param.requires_grad = param.requires_grad

        else:
            self._module_copies = [self.module]

        self.modules_params_data = [[] for _ in range(len(self.device_ids))]
        self.modules_buffers_data = [[] for _ in range(len(self.device_ids))]

        for dev_idx, module in enumerate(self._module_copies):
            self.modules_params_data[dev_idx] = [
                p.data for p in module.parameters()
            ]
            self.modules_buffers_data[dev_idx] = [
                b.data for b in module.buffers()
            ]

        bucket_bytes_cap = bucket_cap_mb * MB

        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
        param_buckets = []
        # Split the parameters into buckets and by types as well
        param_buckets = [
            list(_take_tensors(m.parameters(), bucket_bytes_cap))
            for m in self._module_copies
        ]

        self.bucket_sizes = []
        self.bucket_map = {}

        # We transpose param_buckets, so the loop is over buckets.
        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
            self.bucket_sizes.append(0)
            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
            # of params from each device.
            for param_tuple in zip(*param_buckets_tuple):
                if not param_tuple[0].requires_grad:
                    continue
                for p in param_tuple:
                    self.bucket_map[p] = (bucket_idx,
                                          self.bucket_sizes[bucket_idx])
                self.bucket_sizes[bucket_idx] += 1

        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                         for _ in range(len(self.device_ids))]
                        for i in range(len(self.bucket_sizes))]
        # The number of params ready in each bucket
        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))]
                                   for i in range(len(self.bucket_sizes))]

        # coalesced bucket for only device 0
        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
        # We always reduce the buckets in reverse order,
        # that is, in the order: n - 1, n - 2, ..., 0
        self.next_bucket = len(self.bucket_sizes) - 1
        self.ready_buckets_not_reduced = set()
        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
        self._register_grad_hooks()
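A small pure-Python illustration of the double transposition used in the bucketing loop above: param_buckets is laid out as [device][bucket][param], zip(*param_buckets) yields one tuple per bucket holding that bucket's per-device copies, and zipping again yields one tuple per parameter position across devices. The string values are stand-ins for parameter tensors.

param_buckets = [
    [["p0_dev0", "p1_dev0"], ["p2_dev0"]],   # device 0: two buckets
    [["p0_dev1", "p1_dev1"], ["p2_dev1"]],   # device 1: the same buckets, local copies
]

for bucket_idx, per_device_buckets in enumerate(zip(*param_buckets)):
    for param_tuple in zip(*per_device_buckets):
        # param_tuple holds the copies of one parameter, one entry per device
        print(bucket_idx, param_tuple)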
Example #29
def parallel_apply(modules,
                   inputs,
                   kwargs_tup=None,
                   devices=None):  # pragma: no-cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input, )

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)

                elif module.testing:
                    output = module.test_step(*input, **kwargs)

                else:
                    output = module.validation_step(*input, **kwargs)

                if module.use_dp or module.use_ddp2:
                    auto_squeeze_dim_zeros(output)
                # ---------------

            with lock:
                results[i] = output
        except Exception as ex:
            with lock:
                results[i] = ex

    # TODO: fix hack (maybe not a hack)
    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, input, kwargs, device))
            for i, (
                module, input, kwargs,
                device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
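A hypothetical sketch of what this parallel_apply() expects from its modules: objects shaped like early PyTorch Lightning modules, exposing training/testing flags, use_dp/use_ddp2 flags, and the three *_step hooks. ToyLightningLike is a stand-in, not taken from the source, and the call is guarded to run only when a GPU is available.

import torch

class ToyLightningLike(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.testing = False     # consulted only when module.training is False
        self.use_dp = False      # skip the auto_squeeze_dim_zeros() post-processing
        self.use_ddp2 = False

    def training_step(self, batch):
        return self.layer(batch).mean()

    def validation_step(self, batch):
        return self.layer(batch).sum()

    def test_step(self, batch):
        return self.layer(batch).sum()

if torch.cuda.is_available():
    m = ToyLightningLike().cuda(0)
    m.train()                                    # module.training picks training_step()
    batch = torch.randn(8, 4, device="cuda:0")
    (loss,) = parallel_apply([m], [batch], devices=[0])
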
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 find_unused_parameters=False,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()

        self.is_multi_device_module = len(
            {p.device
             for p in module.parameters()}) > 1
        self.is_cuda = all(
            [p.device.type == 'cuda' for p in module.parameters()])

        if not self.is_cuda or self.is_multi_device_module:
            assert not device_ids and not output_device, (
                "DistributedDataParallel device_ids and output_device arguments "
                "only work with single-device CUDA modules, but got "
                "device_ids {}, output_device {}, and module parameters {}."
            ).format(device_ids, output_device,
                     {p.device
                      for p in module.parameters()})

            self.device_ids = None
            self.output_device = None
        else:
            # Use all devices by default for single-device CUDA modules
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            self.device_ids = list(
                map(lambda x: _get_device_index(x, True), device_ids))

            # if output_device is None:
            #     output_device = device_ids[0]

            self.output_device = _get_device_index(output_device, True)

        if self.is_multi_device_module:
            assert self.is_cuda, (
                "DistributedDataParallel with multi-device module only works "
                "with CUDA devices, but module parameters locate in {}."
            ).format({p.device
                      for p in module.parameters()})

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.broadcast_buffers = broadcast_buffers
        self.find_unused_parameters = find_unused_parameters
        self.require_backward_grad_sync = True
        self.require_forward_param_sync = True

        if check_reduction:
            # This argument is no longer used since the reducer
            # will ensure reduction completes even if some parameters
            # do not receive gradients.
            pass

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * MB)

        # reduction bucket size
        self.bucket_bytes_cap = int(bucket_cap_mb * MB)

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._distributed_broadcast_coalesced(module_states,
                                                  self.broadcast_bucket_size)

        self._ddp_init_helper()
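For completeness, a sketch of the standard single-process-per-GPU usage this constructor is designed for; it assumes the launcher provides MASTER_ADDR/MASTER_PORT, RANK and WORLD_SIZE in the environment, and local_rank here is an illustrative argument.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def wrap_for_ddp(model, local_rank):
    # env:// rendezvous reads MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    # single-device CUDA module: device_ids holds exactly that device
    return DistributedDataParallel(model, device_ids=[local_rank])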