Example 1
def parallel_beam_search(opt, model, batch, fields):
    device_ids = list(range(torch.cuda.device_count()))
    results = {}

    if len(device_ids) <= 1:
        beam_search(opt, model, batch.src, fields, results=results)
        return results[0]

    # 1. scatter input
    sources = scatter(batch.src, device_ids)
    targets = scatter(batch.tgt, device_ids)

    # 2. replicate model
    replicas = replicate(model, device_ids[:len(sources)])
    assert len(replicas) == len(sources) == len(targets)

    # 3. parallel apply
    threads = [
        threading.Thread(target=beam_search,
                         args=(opt, replica, src, fields, idx, results))
        for idx, (replica, src) in enumerate(zip(replicas, sources))
    ]

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    return [h for i in range(len(replicas)) for h in results[i]]
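The thread-per-replica loop above reimplements what torch.nn.parallel already ships. Below is a minimal sketch of the same scatter -> replicate -> parallel apply -> gather pattern using the built-in primitives; the toy module and sizes are illustrative, and at least two CUDA devices are assumed:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply
from torch.nn.parallel.scatter_gather import scatter, gather


def toy_data_parallel_forward(module, batch):
    """Split `batch` over all visible GPUs and run one replica per chunk."""
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) <= 1:
        return module(batch)

    # 1. scatter the batch along dim 0 -> one chunk per device
    chunks = scatter(batch, device_ids)
    # 2. copy the module onto each device that actually received a chunk
    replicas = replicate(module, device_ids[:len(chunks)])
    # 3. run each replica on its chunk in parallel threads
    outputs = parallel_apply(replicas, [(c,) for c in chunks])
    # 4. collect the per-device outputs back on device 0
    return gather(outputs, target_device=0)


if torch.cuda.device_count() > 1:
    model = nn.Linear(8, 4).cuda()
    out = toy_data_parallel_forward(model, torch.randn(16, 8, device="cuda:0"))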
Example 2
    def scatter(self, inputs, kwargs, device_ids):
        """
        len(inputs) = how many inputs the network takes
        len(inputs[0]) = #GPUs * mbs
        """
        final_inputs = []
        if len(inputs[0]) % len(device_ids) != 0:
            raise Exception(
                "Batch size (len(inputs[0])) must be a multiple of the number of devices")

        minibatch_size = int(len(inputs[0]) / len(device_ids))
        for i, device in enumerate(device_ids):
            input_i = inputs[0][i * minibatch_size:(i + 1) * minibatch_size]
            if len(input_i) == 1:
                input_i = input_i[0]
            final_inputs += scatter(input_i, [device],
                                    self.dim) if inputs else []
        if len(device_ids) == 1:
            final_inputs = [final_inputs]
        final_kwargs = scatter(kwargs, device_ids,
                               self.moduledim) if kwargs else []
        if len(final_inputs) < len(final_kwargs):
            final_inputs.extend([
                () for _ in range(len(final_kwargs) - len(final_inputs))
            ])
        elif len(final_kwargs) < len(final_inputs):
            final_kwargs.extend(
                [{} for _ in range(len(final_inputs) - len(final_kwargs))])
        final_inputs = tuple(final_inputs)
        final_kwargs = tuple(final_kwargs)

        return final_inputs, final_kwargs
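The docstring's assumption (len(inputs[0]) = #GPUs * mbs) just means the batch dimension is split evenly, one contiguous slice per device. A minimal CPU-only sketch of that slicing with hypothetical sizes:

import torch

# Hypothetical setup: 4 devices and a per-device mini-batch size (mbs) of 2.
device_ids = [0, 1, 2, 3]
mbs = 2
inputs = (torch.arange(len(device_ids) * mbs),)   # len(inputs[0]) == #GPUs * mbs

assert len(inputs[0]) % len(device_ids) == 0
minibatch_size = len(inputs[0]) // len(device_ids)

chunks = [
    inputs[0][i * minibatch_size:(i + 1) * minibatch_size]
    for i in range(len(device_ids))
]
# chunks[i] is the slice that the loop above hands to scatter(..., [device], dim)
assert all(len(c) == mbs for c in chunks)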
Example 3
    def scatter(self, inputs, kwargs, device_ids):
        """
        Scatters the inputs and kwargs to several GPUs (device_ids).

        Assumptions
        ===========
        len(inputs) = how many inputs the network takes
        len(inputs[0]) = #GPUs * mbs
        """
        final_inputs = []
        for i, device in enumerate(device_ids):
            input_i = inputs[0][i]
            final_inputs += scatter([input_i], [device],
                                    self.dim) if inputs else []
        final_kwargs = scatter(kwargs, device_ids,
                               self.moduledim) if kwargs else []
        if len(final_inputs) < len(final_kwargs):
            final_inputs.extend([
                () for _ in range(len(final_kwargs) - len(final_inputs))
            ])
        elif len(final_kwargs) < len(final_inputs):
            final_kwargs.extend(
                [{} for _ in range(len(final_inputs) - len(final_kwargs))])
        final_inputs = tuple(final_inputs)
        final_kwargs = tuple(final_kwargs)
        return final_inputs, final_kwargs
Example 4
 def scatter(self, input_list, kwargs, device_ids):
     inputs = []
     for input, gpu in zip(input_list, device_ids):
         inputs.extend(scatter(input, [gpu], dim=0))
     kwargs = scatter(kwargs, device_ids, dim=0) if kwargs else []
     if len(inputs) < len(kwargs):
         inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
     elif len(kwargs) < len(inputs):
         kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
     return inputs, kwargs
Example 5
    def forward(self, prediction, sample):

        # [value.cuda() for value in sample.values()]

        device_ids = list(range(torch.cuda.device_count()))
        sample_gpu = scatter(sample, device_ids)

        if len(device_ids) == 1:
            tloss, loss_dict = self.compute_loss_single_gpu(
                prediction, sample_gpu[0])
        else:

            if not self.threaded:
                losses = [
                    self.compute_loss_single_gpu(pred, gt)
                    for pred, gt in zip(prediction, sample_gpu)
                ]
            else:

                modules = [
                    self.compute_loss_single_gpu
                    for i in range(len(device_ids))
                ]  # NOQA

                inputs = [inp for inp in zip(prediction, sample_gpu)]

                losses = parallel_apply(modules, inputs)

            tloss, loss_dict = gather(losses, target_device=0)
            tloss = sum(tloss) / len(tloss)

            for key, value in loss_dict.items():
                loss_dict[key] = sum(value) / len(value)

        return tloss, loss_dict
Example 6
    def forward(self, predictions, sample):

        sample_gpu = scatter(sample, self.device_ids)

        if len(self.device_ids) == 1:
            tloss, loss_dict = self.loss(
                predictions, sample_gpu[0])
        else:

            if not self.threaded:
                losses = [
                    self.loss(pred, gt)
                    for pred, gt in zip(predictions, sample_gpu)]
            else:

                modules = [
                    self.loss for i in range(len(self.device_ids))]

                inputs = [inp for inp in zip(predictions, sample_gpu)]

                losses = parallel_apply(
                    modules, inputs)

            # TODO: make pretty.

            tloss, loss_dict = gather(losses, target_device=0)
            tloss = sum(tloss) / len(tloss)

            for key, value in loss_dict.items():
                loss_dict[key] = sum(value) / len(value)

        return tloss, loss_dict
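The gather(losses, target_device=0) call above works because torch.nn.parallel.gather recurses into tuples and dicts of tensors, so a list of per-device (total_loss, loss_dict) pairs comes back as one (stacked_losses, merged_dict) pair. A minimal sketch with fabricated loss values, assuming at least two CUDA devices:

import torch
from torch.nn.parallel.scatter_gather import gather

if torch.cuda.device_count() > 1:
    # One (total_loss, loss_dict) pair per device, as self.loss would return.
    losses = [
        (torch.tensor([0.5], device="cuda:0"),
         {"l1": torch.tensor([0.2], device="cuda:0")}),
        (torch.tensor([0.7], device="cuda:1"),
         {"l1": torch.tensor([0.4], device="cuda:1")}),
    ]
    tloss, loss_dict = gather(losses, target_device=0)
    # tloss now holds both per-device totals on cuda:0; loss_dict["l1"] likewise.
    tloss = sum(tloss) / len(tloss)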
Example 7
 def forward(self, inputs, **kwargs):
     kwargs = scatter(kwargs, self.device_ids[:len(inputs)], self.dim)
     if len(self.device_ids) == 1:
         return (self.module(*inputs[0], **kwargs[0]), )
     replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
     outputs = self.parallel_apply(replicas, inputs, kwargs)
     return outputs
Example 8
 def forward(self, inputs, targets, **kwargs):
     # inputs should already be scattered;
     # scattering the targets instead
     if not self.device_ids:
         return self.module(inputs, *targets, **kwargs)
     kwargs = scatter(kwargs, self.device_ids[:len(inputs)], self.dim)
     if len(self.device_ids) == 1:
         return self.module(inputs[0], targets[0], **kwargs[0])
     replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
     outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
     return self.gather(outputs, self.output_device)
Example 9
    def step_all_gpus():
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        scattered_inputs = scatter((input_var, target_var), device_ids)
        for in_queue, input_target in zip(in_queues, scattered_inputs):
            in_queue.put(input_target)

        # Wait until all the kernels are enqueued
        for out_queue in out_queues:
            out_queue.get()
Example 10
        def gradient_func(
            forward_fn: Callable,
            inputs: Union[Tensor, Tuple[Tensor, ...]],
            target_ind: TargetType = None,
            additional_forward_args: Any = None,
        ) -> Tuple[Tensor, ...]:
            if self.device_ids is None:
                scattered_inputs = (inputs, )
            else:
                # scatter method does not have a precise enough return type in its
                # stub, so suppress the type warning.
                scattered_inputs = scatter(  # type:ignore
                    inputs, target_gpus=self.device_ids)

            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    is_layer_tuple = (isinstance(hook_outputs, tuple) if
                                      hook_outputs is not None else isinstance(
                                          hook_inputs, tuple))
                    if is_layer_tuple:
                        return scattered_inputs_dict[device]
                    return scattered_inputs_dict[device][0]

                hook = None
                try:
                    if attribute_to_layer_input:
                        hook = self.layer.register_forward_pre_hook(
                            layer_forward_hook)
                    else:
                        hook = self.layer.register_forward_hook(
                            layer_forward_hook)

                    output = _run_forward(self.forward_func, tuple(),
                                          target_ind, additional_forward_args)
                finally:
                    if hook is not None:
                        hook.remove()

                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs.")
                # torch.unbind(forward_out) is a list of scalar tensor tuples and
                # contains batch_size * #steps elements
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads
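The layer_forward_hook above relies on a PyTorch hook mechanic: when a forward pre-hook (or forward hook) returns a value, that value replaces the layer's input (or output). A minimal CPU-only sketch of that mechanism with a toy module, independent of the scatter bookkeeping:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
model = nn.Sequential(nn.Identity(), layer)

replacement = torch.ones(3, 4)  # what the hook will feed into `layer`


def forward_pre_hook(module, inputs):
    # Returning a value from a pre-hook overrides the layer's input, which is
    # how the code above injects scattered_inputs_dict[device].
    return (replacement,)


handle = layer.register_forward_pre_hook(forward_pre_hook)
try:
    out = model(torch.zeros(3, 4))  # the zeros never reach `layer`
    assert torch.allclose(out, layer(replacement))
finally:
    handle.remove()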
Example 11
    def parallel_forward(self, inputs, **kwargs):
        """Multi-GPU Mult-size Evaluation

        Args:
            inputs: list of Tensors
        """
        inputs = [(input.unsqueeze(0).cuda(device),) for input, device in zip(inputs, self.device_ids)]
        replicas = self.replicate(self, self.device_ids[:len(inputs)])
        kwargs = scatter(kwargs, self.device_ids[:len(inputs)],
                         self.dim) if kwargs else []
        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs
Example 12
        def layer_forward_func(*args):
            layer_length = args[-1]
            layer_input = args[:layer_length]
            original_inputs = args[layer_length:-1]

            device_ids = self.device_ids
            if device_ids is None:
                device_ids = getattr(self.forward_func, "device_ids", None)

            all_layer_inputs = {}
            if device_ids is not None:
                scattered_layer_input = scatter(layer_input,
                                                target_gpus=device_ids)
                for device_tensors in scattered_layer_input:
                    all_layer_inputs[device_tensors[0].device] = device_tensors
            else:
                all_layer_inputs[layer_input[0].device] = layer_input

            def forward_hook(module, inp, out=None):
                device = _extract_device(module, inp, out)
                is_layer_tuple = (isinstance(out, tuple) if out is not None
                                  else isinstance(inp, tuple))
                if device not in all_layer_inputs:
                    raise AssertionError(
                        "Layer input not placed on appropriate "
                        "device. If using a DataParallel model, either provide the "
                        "DataParallel model as forward_func or provide device ids"
                        " to the constructor.")
                if not is_layer_tuple:
                    return all_layer_inputs[device][0]
                return all_layer_inputs[device]

            hook = None
            try:
                if attribute_to_layer_input:
                    hook = self.layer.register_forward_pre_hook(forward_hook)
                else:
                    hook = self.layer.register_forward_hook(forward_hook)
                eval = _run_forward(self.forward_func,
                                    original_inputs,
                                    target=target)
            finally:
                if hook is not None:
                    hook.remove()
            return eval
Example 13
File: utils.py Project: mjpyeon/DGL
    def forward(self, *inputs, init=False, **kwargs):
        if init:
            if self.device_ids:
                # -------- Here, we split the input tensor across GPUs
                inputs_ = inputs
                if not isinstance(inputs_, tuple):
                    inputs_ = (inputs_, )

                representation, _ = scatter_kwargs(inputs_, None,
                                                   self.device_ids, 0)
                self.replicas = self.replicate(
                    self.module, self.device_ids[:len(representation)])
                # ----
            else:
                representation = inputs
            return None, representation

        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        # inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
        if len(self.device_ids) == 1:
            import ipdb
            ipdb.set_trace()
            return self.module(*inputs[0][0], **kwargs)

        kwargs = scatter(kwargs, self.device_ids) if kwargs else []
        #   if len(inputs) < len(kwargs):
        #      inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        # elif len(kwargs) < len(inputs):
        #    kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        kwargs = tuple(kwargs)
        outputs = self.parallel_apply(self.replicas, *inputs, kwargs)

        out1 = []
        out2 = []
        for i, tensor in enumerate(outputs):
            with torch.cuda.device(tensor[0].get_device()):
                # out_1[i] = torch.autograd.Variable(tensors[i])
                out1.append(outputs[i][0])
                out2.append(outputs[i][1])
        outputs = self.gather(out1, self.output_device)
        representation = out2
        return outputs, representation
Example 14
        def gradient_func(forward_fn,
                          inputs,
                          target_ind=None,
                          additional_forward_args=None):
            if self.device_ids is None:
                scattered_inputs = (inputs, )
            else:
                scattered_inputs = scatter(inputs, target_gpus=self.device_ids)

            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    if is_layer_tuple:
                        return scattered_inputs_dict[device]
                    return scattered_inputs_dict[device][0]

                if attribute_to_layer_input:
                    hook = self.layer.register_forward_pre_hook(
                        layer_forward_hook)
                else:
                    hook = self.layer.register_forward_hook(layer_forward_hook)

                output = _run_forward(
                    self.forward_func,
                    additional_forward_args,
                    target_ind,
                )
                hook.remove()
                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs.")
                # torch.unbind(forward_out) is a list of scalar tensor tuples and
                # contains batch_size * #steps elements
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads
Example 15
def calc_gradient_penalty(netD, real_data, fake_data, data_parallel):
    # Sample a random interpolation coefficient for each sample in the batch
    bs, c, h, w, d = real_data.size()
    
    alpha = torch.rand(bs, 1, 1, 1, 1)
    alpha = alpha.expand(bs, c, h, w, d).contiguous()
    alpha = alpha.to(device = real_data.get_device())
    
    # Generate interpolated samples - they do not have to make visual sense; this is
    # just a regularization that keeps the discriminator smooth for any input
    interpolates = alpha * real_data.detach() + ((1.0 - alpha) * fake_data.detach())
    interpolates.requires_grad_(True)  # tensors returned by detach() have requires_grad disabled
    
    # Scatter the interpolates across gpus
    if data_parallel:
        interpolates_scatter = scatter(interpolates, netD.device_ids) 
        interpolates_scatter = tuple(interpolates_scatter)
        
        # wrap one more layer around each chunk to fit the discriminator input format
        interpolates_scatter = tuple((interp,) for interp in interpolates_scatter)
        disc_interpolates = netD(interpolates_scatter).mean()  # pass through the discriminator, which does the gather
    else:
        disc_interpolates = netD(interpolates).mean() # Pass through discriminator. The discriminator does the gather
        

    # create_graph=True and retain_graph=True build the graph of the gradient computation
    # itself, allowing higher-order gradients; we need the gradient of a gradient here.
    # grad_outputs supplies the initial (all-ones) gradient for the vector-Jacobian product.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device = disc_interpolates.get_device()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.view(gradients.size(0), -1)  
    
    # The penalty drives the discriminator's gradient norm towards 1,
    # i.e. a Lipschitz constant of 1.
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty, interpolates
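The create_graph=True detail above is what makes the penalty trainable: the gradient itself has to stay differentiable so its norm can be penalised and backpropagated. A minimal CPU-only sketch of a differentiable gradient norm on a toy function (no discriminator involved):

import torch
from torch import autograd

x = torch.randn(8, 3, requires_grad=True)
y = (x ** 2).sum()                     # stand-in for the discriminator output

# create_graph=True keeps the graph of the gradient computation itself, so the
# gradient norm below is still a function of x and can be backpropagated.
grads = autograd.grad(outputs=y, inputs=x,
                      grad_outputs=torch.ones_like(y),
                      create_graph=True, retain_graph=True, only_inputs=True)[0]

grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
penalty = ((grad_norm - 1.0) ** 2).mean()
penalty.backward()                     # second-order gradients flow back into x
assert x.grad is not None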
Example 16
    def forward(self, inputs, **kwargs):
        """
    inputs should be a list of dgl.NodeFlows when multi-gpus is enabled.
    The length of inputs should be equal (or less) to device num.
    Each element in inputs should be an instance of nodeflow
    """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    "module must have its parameters and buffers "
                    "on device {} (device_ids[0]) but found one of "
                    "them on device: {}".format(self.src_device_obj, t.device))

        if not isinstance(inputs, list):
            inputs = [inputs]
        if len(self.device_ids) < len(inputs):
            raise RuntimeError(
                "number of devices [{}] is smaller than number of inputs [{}]".format(
                    len(self.device_ids), len(inputs)))
        # replicate kwargs
        kwargs = scatter(kwargs, self.device_ids[:len(inputs)], 0)
        if len(self.device_ids) == 1:
            device = torch.device(0) if self.use_cuda else torch.device('cpu')
            inputs[0].copy_from_parent(ctx=device)
            return self.module(inputs[0])
        elif isinstance(inputs[0], NodeFlow):
            # copy inputs from the parent graph (which should reside on cuda:0);
            # for small graphs a better option would be to replicate the parent
            # features to all GPUs and let each GPU load from its own copy
            for device_id in range(len(inputs)):
                device = torch.device(self.device_ids[device_id])
                inputs[device_id].copy_from_parent(ctx=device)
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)
Example 17
        def layer_forward_func(*args):
            layer_input = args[0]
            original_inputs = args[1:]

            device_ids = self.device_ids
            if device_ids is None:
                device_ids = getattr(self.forward_func, "device_ids", None)

            all_layer_inputs = {}
            if device_ids is not None:
                scattered_layer_input = scatter(layer_input,
                                                target_gpus=device_ids)
                for tensor in scattered_layer_input:
                    all_layer_inputs[tensor.device] = tensor
            else:
                all_layer_inputs[layer_input.device] = layer_input

            def forward_hook(module, inp, out=None):
                device = _extract_device(module, inp, out)
                if device not in all_layer_inputs:
                    raise AssertionError(
                        "Layer input not placed on appropriate "
                        "device. If using a DataParallel model, either provide the "
                        "DataParallel model as forward_func or provide device ids"
                        " to the constructor.")
                return all_layer_inputs[device]

            if attribute_to_layer_input:
                hook = self.layer.register_forward_pre_hook(forward_hook)
            else:
                hook = self.layer.register_forward_hook(forward_hook)
            eval = _run_forward(self.forward_func,
                                original_inputs,
                                target=target)
            hook.remove()
            return eval
Example 18
    def parallel_forward(self, dense_x, lS_o, lS_i):
        ### prepare model (overwrite) ###
        # WARNING: # of devices must be >= batch size in parallel_forward call
        batch_size = dense_x.size()[0]
        ndevices = min(self.ndevices, batch_size, len(self.emb_l))
        device_ids = range(ndevices)
        # WARNING: must redistribute the model if the mini-batch size changes (this is
        # common for the last mini-batch, when the dataset size is not evenly divisible
        # by the batch size)
        if self.parallel_model_batch_size != batch_size:
            self.parallel_model_is_not_prepared = True

        if self.sync_dense_params or self.parallel_model_is_not_prepared:
            # replicate mlp (data parallelism)
            self.bot_l_replicas = replicate(self.bot_l, device_ids)
            self.top_l_replicas = replicate(self.top_l, device_ids)
            # distribute embeddings (model parallelism)
            t_list = []
            for k, emb in enumerate(self.emb_l):
                d = torch.device("cuda:" + str(k % ndevices))
                t_list.append(emb.to(d))
            self.emb_l = nn.ModuleList(t_list)
            self.parallel_model_batch_size = batch_size
            self.parallel_model_is_not_prepared = False

        ### prepare input (overwrite) ###
        # scatter dense features (data parallelism)
        # print(dense_x.device)
        dense_x = scatter(dense_x, device_ids, dim=0)
        # distribute sparse features (model parallelism)
        if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
            sys.exit(
                "ERROR: corrupted model input detected in parallel_forward call"
            )

        t_list = []
        i_list = []
        for k, _ in enumerate(self.emb_l):
            d = torch.device("cuda:" + str(k % ndevices))
            t_list.append(lS_o[k].to(d))
            i_list.append(lS_i[k].to(d))
        lS_o = t_list
        lS_i = i_list

        ### compute results in parallel ###
        # bottom mlp
        # WARNING: Note that the self.bot_l is a list of bottom mlp modules
        # that have been replicated across devices, while dense_x is a tuple of dense
        # inputs that has been scattered across devices on the first (batch) dimension.
        # The output is a list of tensors scattered across devices according to the
        # distribution of dense_x.
        x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
        # debug prints
        # print(x)

        # embeddings
        ly = self.apply_emb(lS_o, lS_i, self.emb_l)
        # debug prints
        # print(ly)

        # butterfly shuffle (implemented inefficiently for now)
        # WARNING: Note that at this point we have the result of the embedding lookup
        # for the entire batch on each device. We would like to obtain partial results
        # corresponding to all embedding lookups, but part of the batch on each device.
        # Therefore, matching the distribution of output of bottom mlp, so that both
        # could be used for subsequent interactions on each device.
        if len(self.emb_l) != len(ly):
            sys.exit(
                "ERROR: corrupted intermediate result in parallel_forward call"
            )

        t_list = []
        for k, _ in enumerate(self.emb_l):
            d = torch.device("cuda:" + str(k % ndevices))
            y = scatter(ly[k], device_ids, dim=0)
            t_list.append(y)
        # adjust the list to be ordered per device
        ly = list(map(lambda y: list(y), zip(*t_list)))
        # debug prints
        # print(ly)

        # interactions
        z = []
        for k in range(ndevices):
            zk = self.interact_features(x[k], ly[k])
            z.append(zk)
        # debug prints
        # print(z)

        # top mlp
        # WARNING: Note that the self.top_l is a list of top mlp modules that
        # have been replicated across devices, while z is a list of interaction results
        # that by construction are scattered across devices on the first (batch) dim.
        # The output is a list of tensors scattered across devices according to the
        # distribution of z.
        p = parallel_apply(self.top_l_replicas, z, None, device_ids)

        ### gather the distributed results ###
        p0 = gather(p, self.output_d, dim=0)

        # clamp output if needed
        if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
            z0 = torch.clamp(p0,
                             min=self.loss_threshold,
                             max=(1.0 - self.loss_threshold))
        else:
            z0 = p0

        return z0
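The butterfly shuffle above boils down to a transposition: t_list is indexed [embedding][device], and zip(*t_list) reorders it to [device][embedding], so each device ends up with every embedding's chunk for its part of the batch. A minimal pure-Python sketch with toy placeholders:

# Hypothetical: 3 embedding tables, 2 devices.
# t_list[k][j] is the chunk of embedding k that belongs on device j.
t_list = [
    ["emb0@dev0", "emb0@dev1"],
    ["emb1@dev0", "emb1@dev1"],
    ["emb2@dev0", "emb2@dev1"],
]

# Same reordering as `ly = list(map(lambda y: list(y), zip(*t_list)))` above:
ly = [list(per_device) for per_device in zip(*t_list)]

assert ly == [
    ["emb0@dev0", "emb1@dev0", "emb2@dev0"],  # everything device 0 needs
    ["emb0@dev1", "emb1@dev1", "emb2@dev1"],  # everything device 1 needs
]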
Example 19
 def scatter(self, inputs, device_ids):
     return scatter(inputs, device_ids)
Example 20
def _scatter_kwargs(kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    kwargs = tuple(kwargs)
    return kwargs
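When scatter is applied to a kwargs dict, as above, it returns one dict per device: tensor values are split along dim, while non-tensor values are replicated to every device. A minimal sketch of that behaviour, assuming at least two CUDA devices (the keys and values are made up):

import torch
from torch.nn.parallel.scatter_gather import scatter

if torch.cuda.device_count() > 1:
    device_ids = [0, 1]
    kwargs = {"mask": torch.ones(4, 5), "temperature": 0.7}

    per_device = scatter(kwargs, device_ids, dim=0)      # one dict per device
    assert len(per_device) == len(device_ids)
    assert per_device[0]["mask"].shape[0] == 2           # 4 rows split over 2 GPUs
    assert per_device[0]["mask"].device == torch.device("cuda:0")
    assert per_device[1]["temperature"] == 0.7           # non-tensors are replicated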
Example 21
 def _move(self, t: TensorOrList) -> TensorOrList:
     """ move (nested) tensors to the assigned devices """
     return scatter(t, target_gpus=self.indices)[0]
Example 22
        def gradient_func(
            forward_fn: Callable,
            inputs: Union[Tensor, Tuple[Tensor, ...]],
            target_ind: TargetType = None,
            additional_forward_args: Any = None,
        ) -> Tuple[Tensor, ...]:
            if self.device_ids is None or len(self.device_ids) == 0:
                scattered_inputs = (inputs, )
            else:
                # scatter method does not have a precise enough return type in its
                # stub, so suppress the type warning.
                scattered_inputs = scatter(  # type:ignore
                    inputs, target_gpus=self.device_ids)

            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(module,
                                       hook_inputs,
                                       hook_outputs=None,
                                       layer_idx=0):
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    is_layer_tuple = (
                        isinstance(hook_outputs, tuple)
                        # hook_outputs is None if attribute_to_layer_input == True
                        if hook_outputs is not None else isinstance(
                            hook_inputs, tuple))

                    if is_layer_tuple:
                        return scattered_inputs_dict[device][
                            num_outputs_cumsum[layer_idx]:num_outputs_cumsum[
                                layer_idx + 1]]

                    return scattered_inputs_dict[device][
                        num_outputs_cumsum[layer_idx]]

                hooks = []
                try:

                    layers = self.layer
                    if not isinstance(layers, list):
                        layers = [self.layer]

                    for layer_idx, layer in enumerate(layers):
                        hook = None
                        # TODO:
                        # Allow multiple attribute_to_layer_input flags for
                        # each layer, i.e. attribute_to_layer_input[layer_idx]
                        if attribute_to_layer_input:
                            hook = layer.register_forward_pre_hook(
                                functools.partial(layer_forward_hook,
                                                  layer_idx=layer_idx))
                        else:
                            hook = layer.register_forward_hook(
                                functools.partial(layer_forward_hook,
                                                  layer_idx=layer_idx))

                        hooks.append(hook)

                    output = _run_forward(self.forward_func, tuple(),
                                          target_ind, additional_forward_args)
                finally:
                    for hook in hooks:
                        if hook is not None:
                            hook.remove()

                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs.")
                # torch.unbind(forward_out) is a list of scalar tensor tuples and
                # contains batch_size * #steps elements
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads