Example #1
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
    This is the functional version of the DataParallel module.
    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids,
                                           dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
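A minimal usage sketch of the helper above (the model, shapes, and two-GPU setup are illustrative assumptions, not part of the source):

# Hypothetical usage; assumes at least two visible CUDA devices.
import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()                       # lives on device_ids[0]
batch = torch.randn(64, 16, device='cuda:0')          # split along dim 0
out = data_parallel(model, batch, device_ids=[0, 1])  # result gathered on cuda:0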
Example #2
    def wrapped(self, *inputs, **module_kwargs):
        if (not hasattr(self, '_is_replica')) and inputs[0].is_cuda:
            device_count = torch.cuda.device_count()
            if inputs[0].shape[0] % device_count != 0:
                import os
                cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
                raise ValueError(
                    'batch size (%d) must be divisible by the number of GPUs (%d) used\n CUDA_VISIBLE_DEVICES: %s'
                    % (inputs[0].shape[0], device_count, cuda_visible_devices))
            if device_count > 1:
                # modified from pytorch (torch.nn.parallel.DataParallel)
                device_ids = list(range(device_count))
                output_device = device_ids[0]
                inputs, kwargs = scatter_kwargs(inputs, module_kwargs,
                                                device_ids)
                replicas = replicate(self, device_ids[:len(inputs)])

                # add a _is_replica flag to avoid infinite loop
                # from recursively calling parallel_apply
                for replica in replicas:
                    replica._is_replica = True
                outputs = parallel_apply(replicas, inputs, kwargs)
                return gather(outputs, output_device)

        return self._forward_worker(*inputs, **module_kwargs)
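A hedged sketch of how the wrapper above might be attached to a module; the class name and the convention of keeping the original forward as _forward_worker are assumptions made for illustration:

# Hypothetical attachment of `wrapped`; assumes `wrapped` (and the scatter/replicate/
# parallel_apply/gather helpers it uses) are importable at module level.
import torch

class MyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def _forward_worker(self, x):
        # the original single-device forward pass
        return self.fc(x)

    # `wrapped` splits CUDA batches across all visible GPUs and falls back to
    # _forward_worker on replicas (which carry _is_replica) or on CPU input.
    forward = wrapped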
Example #3
    def scatter(self, inputs, kwargs, device_ids):
        _inputs = []
        _kwargs = {}
        input_constructors = {}
        kwargs_constructors = {}
        for i, item in enumerate(inputs):
            if isinstance(item, Scatterable):
                input_constructors[i] = item.from_kwargs
                _inputs.append(item.to_kwargs())
            else:
                input_constructors[i] = lambda x: x
                _inputs.append(item)

        for key, item in kwargs.items():
            if isinstance(item, Scatterable):
                kwargs_constructors[key] = item.from_kwargs
                _kwargs[key] = item.to_kwargs()
            else:
                kwargs_constructors[key] = lambda x: x
                _kwargs[key] = item

        _inputs, _kwargs = scatter_kwargs(_inputs,
                                          _kwargs,
                                          device_ids,
                                          dim=self.dim)

        _inputs = [[
            input_constructors[i](item) for i, item in enumerate(_input)
        ] for _input in _inputs]
        _kwargs = [{
            k: kwargs_constructors[k](item)
            for k, item in _kwarg.items()
        } for _kwarg in _kwargs]

        return _inputs, _kwargs
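The Scatterable protocol itself is not shown in this example; a minimal sketch of what an implementation could look like, given how to_kwargs/from_kwargs are used above (the class and field names are illustrative; Scatterable is the protocol referenced in the example):

# Hypothetical Scatterable: exposes its tensors as a dict so scatter_kwargs can
# split them along the batch dimension, then rebuilds one instance per device.
class PaddedBatch(Scatterable):
    def __init__(self, tokens, mask):
        self.tokens = tokens
        self.mask = mask

    def to_kwargs(self):
        return {'tokens': self.tokens, 'mask': self.mask}

    @classmethod
    def from_kwargs(cls, chunk):
        # `chunk` is the per-device dict produced by scatter_kwargs
        return cls(**chunk)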
Example #4
 def wrapper(network, *inputs, **kwargs):
     inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
     if len(device_ids) == 1:
         return getattr(network, func_name)(*inputs[0], **kwargs[0])
     replicas = replicate(network, device_ids[:len(inputs)])
     outputs = parallel_apply(replicas, func_name, inputs, kwargs,
                              device_ids[:len(replicas)])
     return gather(outputs, output_device)
Example #5
def data_parallel(module,
                  inputs,
                  device_ids=None,
                  output_device=None,
                  dim=0,
                  module_kwargs=None,
                  dont_scatter=False,
                  dont_gather=False):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]

    if not dont_scatter:
        do_scatter_lists = isinstance(inputs[0], list)
        if do_scatter_lists:
            inputs, module_kwargs = scatter_lists(inputs, module_kwargs,
                                                  device_ids)
        else:
            inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs,
                                                   device_ids, dim)

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    if dont_gather:
        return tuple([[out[i] for out in outputs]
                      for i in range(len(outputs[0]))])
    return gather(outputs, output_device, dim)
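For reference, the dont_gather branch transposes the per-device outputs so the caller gets one entry per output position rather than one per replica; a small illustration with placeholder values standing in for tensors:

# Two replicas, each returning a (logits, loss) pair:
outputs = [('logits_gpu0', 'loss_gpu0'), ('logits_gpu1', 'loss_gpu1')]
transposed = tuple([[out[i] for out in outputs] for i in range(len(outputs[0]))])
# transposed == (['logits_gpu0', 'logits_gpu1'], ['loss_gpu0', 'loss_gpu1'])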
Example #6
    def scatter(self, inputs, kwargs, device_ids):
        if not isinstance(self.module, MetaModule):
            return super(DataParallel, self).scatter(inputs, kwargs, device_ids)

        params = kwargs.pop('params', None)
        inputs_, kwargs_ = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
        # Add params argument unchanged back in kwargs
        replicas = self._replicate_params(params, inputs_, device_ids,
                                          detach=not torch.is_grad_enabled())
        kwargs_ = tuple(dict(params=replica, **kwarg)
                        for (kwarg, replica) in zip(kwargs_, replicas))
        return inputs_, kwargs_
Example #7
def data_parallel(
    module,
    inputs,
    device_ids=None,
    output_device=None,
    dim=0,
    module_kwargs=None,
    non_scatter_kwargs=None,
):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))

    for tensor in chain(module.parameters(), module.buffers()):
        if tensor.device != src_device_obj:
            raise RuntimeError(
                "module must have its parameters and buffers "
                "on device {} (device_ids[0]) but found one of "
                "them on device: {}".format(src_device_obj, tensor.device)
            )

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[: len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
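This version mirrors the stock functional API, including the check that the module's parameters and buffers live on device_ids[0]. A short usage sketch against the built-in torch.nn.parallel.data_parallel (the model, shapes, and two-GPU setup are illustrative):

# Calls the built-in functional API that this example reimplements.
import torch
import torch.nn as nn

model = nn.Linear(32, 8).to('cuda:0')    # parameters must sit on device_ids[0]
x = torch.randn(128, 32, device='cuda:0')
y = torch.nn.parallel.data_parallel(model, x, device_ids=[0, 1], output_device=0)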
Example #8
    def _data_parallel(self, batch):
        """
        Do the forward pass using multiple GPUs.  This is a simplification
        of torch.nn.parallel.data_parallel to support the allennlp model
        interface.
        """
        inputs, module_kwargs = scatter_kwargs((), batch, self._cuda_devices, 0)
        used_device_ids = self._cuda_devices[:len(inputs)]
        replicas = replicate(self._model, used_device_ids)
        outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)

        # Only the 'loss' is needed.
        # a (num_gpu, ) tensor with loss on each GPU
        losses = gather([output['loss'] for output in outputs], used_device_ids[0], 0)
        return {'loss': losses.mean()}
Example #9
    def _data_parallel(self, batch):
        """
        Do the forward pass using multiple GPUs.  This is a simplification
        of torch.nn.parallel.data_parallel to support the allennlp model
        interface.
        """
        inputs, module_kwargs = scatter_kwargs((), batch, self._cuda_devices, 0)
        used_device_ids = self._cuda_devices[:len(inputs)]
        replicas = replicate(self._model, used_device_ids)
        outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)

        # Only the 'loss' is needed.
        # a (num_gpu, ) tensor with loss on each GPU
        losses = gather([output['loss'].unsqueeze(0) for output in outputs], used_device_ids[0], 0)
        return {'loss': losses.mean()}
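The only change from the previous example is the unsqueeze(0): each per-GPU loss is a 0-dim tensor, and gather concatenates along a dimension, so the scalars are given a size-1 dimension before gathering. A small CPU-only illustration (torch.cat stands in for the multi-device gather):

# Illustrative only; the real code gathers per-GPU losses onto one device.
import torch

losses = [torch.tensor(0.7), torch.tensor(0.9)]        # one 0-dim loss per GPU
stacked = torch.cat([l.unsqueeze(0) for l in losses])  # shape: (num_gpu,)
mean_loss = stacked.mean()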
Example #10
    def forward(self,  # type: ignore
                inputs: List[torch.Tensor],
                *targets: Tuple[torch.Tensor],
                **kwargs: Dict[str, Any]) -> torch.Tensor:
        # inputs are expected to be already scattered
        # scattering the targets instead
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        _targets, _kwargs = scatter_kwargs(targets, kwargs, self.device_ids, dim=self.dim)
        if len(self.device_ids) == 1:
            return self.module(inputs, *_targets[0], **_kwargs[0])
        autocast_if_needed = CriterionWithAutocast(module=self.module) if self.use_mixed_precision else self.module
        replicas = self.replicate(autocast_if_needed, self.device_ids[:len(inputs)])  # type: ignore

        input_tuples: List[Tuple[torch.Tensor, ...]] = [(i, *t) for i, t in zip(inputs, _targets)]
        outputs = torch.nn.parallel.parallel_apply(replicas, input_tuples, _kwargs)

        return gather(outputs, self.output_device, dim=self.dim)
Example #11
File: utils.py Project: mjpyeon/DGL
    def forward(self, *inputs, init=False, **kwargs):
        if init:
            if self.device_ids:
                # -------- Here, we split the input tensor across GPUs
                inputs_ = inputs
                if not isinstance(inputs_, tuple):
                    inputs_ = (inputs_, )

                representation, _ = scatter_kwargs(inputs_, None,
                                                   self.device_ids, 0)
                self.replicas = self.replicate(
                    self.module, self.device_ids[:len(representation)])
                # ----
            else:
                representation = inputs
            return None, representation

        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        # inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0][0], **kwargs)

        kwargs = scatter(kwargs, self.device_ids) if kwargs else []
        kwargs = tuple(kwargs)
        outputs = self.parallel_apply(self.replicas, *inputs, kwargs)

        out1 = []
        out2 = []
        for i, tensor in enumerate(outputs):
            with torch.cuda.device(tensor[0].get_device()):
                # out_1[i] = torch.autograd.Variable(tensors[i])
                out1.append(outputs[i][0])
                out2.append(outputs[i][1])
        outputs = self.gather(out1, self.output_device)
        representation = out2
        return outputs, representation
Example #12
    def scatter(self, inputs, kwargs, device_ids):
        try:
            params = kwargs.pop('params')
        except KeyError:
            return super(DataParallel, self).scatter(inputs, kwargs,
                                                     device_ids)

        inputs_, kwargs_ = scatter_kwargs(inputs,
                                          kwargs,
                                          device_ids,
                                          dim=self.dim)
        # Add params argument unchanged back in kwargs
        replicas = self._replicate_params(params,
                                          inputs_,
                                          device_ids,
                                          detach=not torch.is_grad_enabled())
        kwargs_ = tuple(
            dict(params=replica, **kwarg)
            for (kwarg, replica) in zip(kwargs_, replicas))
        return inputs_, kwargs_
Example #13
 def scatter(self, inputs, kwargs, device_ids):
     return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
Example #14
 def scatter(self, inputs, kwargs, device_ids):
     outputs = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
     return outputs
Example #15
 def scatter(self, inputs, kwargs, device_ids):
     #return my_scatter(inputs, target_gpus=device_ids)
     outputs = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
     return outputs
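Examples #13 through #15 all override DataParallel.scatter with a direct call to scatter_kwargs, which matches what the base class does by default; a minimal sketch of the surrounding subclass (the class name is an illustrative assumption):

# Hypothetical subclass showing where such a scatter override typically lives.
from torch.nn.parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs

class CustomDataParallel(DataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        # same behaviour as the default implementation; kept as a hook for
        # customizing how inputs are split across devices
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)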