Example No. 1
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
        with torch.random.fork_rng(devices=rng_devices,
                                   enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
                get_cuda_rng_tracker().set_states(
                    ctx.fwd_cuda_rng_state_tracker)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
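All of the examples on this page rely on the same trick: stash the RNG state during the forward pass and replay it inside torch.random.fork_rng so that stochastic layers (dropout, for instance) produce identical results when the forward is recomputed; fork_rng then restores the surrounding state on exit. A minimal, self-contained sketch of that pattern, written here purely as an illustration (CPU RNG only; CUDA states would be handled the same way with set_device_states):

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.randn(4)

cpu_state = torch.get_rng_state()          # stash the CPU RNG state before the stochastic op
y_fwd = drop(x)                            # "forward" result

with torch.random.fork_rng(devices=[], enabled=True):
    torch.set_rng_state(cpu_state)         # replay the stashed state inside the fork
    y_recomputed = drop(x)                 # recomputation is bit-identical to the forward

assert torch.equal(y_fwd, y_recomputed)    # fork_rng restores the outer RNG state on exit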
Example No. 2
    def backward_pass(self, y1, y2, dy1, dy2, attn_mask=None):
        """
        :param y1:
        :param y2:
        :param dy1:
        :param dy2:
        :param attn_mask:
        :return:
        """
        """Implementation of the backward pass for reversible transformer encoder"""

        with torch.enable_grad():
            y1.requires_grad = True
            with torch.random.fork_rng(devices=self.ffn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.ffn_cpu_state)
                set_device_states(self.ffn_gpu_devices, self.ffn_gpu_states)

                z2 = self.feedforward(y1)

            # backpropagate dy2 through the feed-forward block; gradients accumulate in y1.grad
            torch.autograd.backward(z2, dy2)

        with torch.no_grad():
            # restore X2 = Y2 - G(Y1)
            x2 = y2 - z2
            del z2, y2

            # DX1 = DY1 + Y1.grad
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():

            x2.requires_grad = True

            with torch.random.fork_rng(devices=self.attn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.attn_cpu_state)
                set_device_states(self.attn_gpu_devices, self.attn_gpu_states)

                z1, _, _ = self.self_attn(x2, attn_mask)

            z1.backward(dx1)

        with torch.no_grad():
            # restore X1 = Y1 - F(X2)
            x1 = y1 - z1
            del y1, z1

            dx2 = dy2 + x2.grad
            x2.grad = None
            del dy2
            x2 = x2.detach()

        return x1, x2, dx1, dx2
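For orientation, the coupling that backward_pass inverts is the usual reversible-residual form y1 = x1 + F(x2), y2 = x2 + G(y1): the function recomputes G and F under the recorded RNG state, reconstructs x2 and x1 by subtraction, and accumulates the matching gradients without storing intermediate activations. A toy sketch of the forward coupling and its exact inversion, with stand-in F and G rather than the module's real sub-layers:

import torch

F = torch.nn.Linear(8, 8)    # stand-in for the self-attention block
G = torch.nn.Linear(8, 8)    # stand-in for the feed-forward block

x1, x2 = torch.randn(8), torch.randn(8)

# forward coupling
y1 = x1 + F(x2)
y2 = x2 + G(y1)

# exact inversion, as performed inside backward_pass
x2_rec = y2 - G(y1)
x1_rec = y1 - F(x2_rec)

assert torch.allclose(x1_rec, x1, atol=1e-6)
assert torch.allclose(x2_rec, x2, atol=1e-6)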
Example No. 3
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices,
                                   enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    chkpt.set_device_states(ctx.fwd_gpu_devices,
                                            ctx.fwd_gpu_states)
            detached_inputs = chkpt.detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(
                    ctx.had_autocast_in_fwd):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # run backward() only with the outputs that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("none of output has requires_grad=True,"
                               " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)
        return (None, None, None) + grads
Example No. 4
    def backward_pass(self, y1, y2, dy1, dy2, pos, attn_mask=None):
        """
        :param pos:
        :param y1:
        :param y2:
        :param dy1:
        :param dy2:
        :param attn_mask:
        :return:
        """
        """Implementation of the backward pass for reversible transformer encoder"""

        with torch.enable_grad():
            y1.requires_grad = True
            with torch.random.fork_rng(devices=self.ffn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.ffn_cpu_state)
                set_device_states(self.ffn_gpu_devices, self.ffn_gpu_states)

                gy1 = self.feedforward(y1)

            gy1.backward(dy2)

        with torch.no_grad():
            # restore X2 = Y2 - G(Y1)
            x2 = y2 - gy1
            del gy1, y2

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True

            with torch.random.fork_rng(devices=self.attn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.attn_cpu_state)
                set_device_states(self.attn_gpu_devices, self.attn_gpu_states)

                fx2, _ = self.self_attn(x2, pos, key_padding_mask=attn_mask)

            fx2.backward(dx1)

        with torch.no_grad():
            # restore X1 = Y1 - F(X2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            x2.grad = None
            del dy2

        return x1, x2, dx1, dx2
Example No. 5
    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        # replay the recorded RNG state so stochastic layers repeat the recorded run
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
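The record_rng helper and the cpu_state / gpu_states attributes are not shown in this example. A plausible counterpart, sketched here as an assumption rather than the original code, stores the CPU RNG state and the per-device CUDA states with torch.utils.checkpoint.get_device_states:

import torch
from torch.utils.checkpoint import get_device_states

def record_rng(self, *args):
    # Hypothetical implementation: remember the RNG state so that a later call
    # with set_rng=True can replay exactly the same random draws.
    self.cpu_state = torch.get_rng_state()
    self.cuda_in_fwd = torch.cuda.is_initialized()
    if self.cuda_in_fwd:
        self.gpu_devices, self.gpu_states = get_device_states(*args)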
Example No. 6
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices,
                                   enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Skip None items and tensors whose requires_grad is False in the backward pass.
        backward_outputs = []
        backward_args = []
        for o, a in zip(outputs, args):
            if o is not None and o.requires_grad:
                backward_outputs.append(o)
                backward_args.append(a)
        torch.autograd.backward(backward_outputs, backward_args)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
Example No. 7
    def backward(ctx, *grad_output):
        with torch.enable_grad():
            detached_inputs = [
                detach_variable(v.to(device, non_blocking=True))
                for v, device in zip(ctx.saved_tensors, ctx.devices)
            ]
            state = nest.pack_sequence_as(ctx.structure, detached_inputs)
            next_state = state
            rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=True):
                torch.set_rng_state(ctx.fwd_cpu_state)
                set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
                for _ in range(ctx.block_size):
                    next_state = ctx.body_fn(next_state)
        next_state = nest.flatten(next_state)

        # keep only (state, grad) pairs whose state requires grad
        next_state, grad_output = zip(
            *[sg for sg in zip(next_state, grad_output) if sg[0].requires_grad])
        torch.autograd.backward(next_state, grad_output)

        return (None, None, None) + tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs)
Example No. 8
    def backward(ctx, *grad_output):
        '''
        :param ctx: context object, analogous to self
        :param grad_output: gradients flowing in from the following module
        :return: gradients for the forward inputs; the number of returned values must equal the number of forward parameters minus one, because ctx is not counted
        '''

        # Get output that saved by forward function
        bak_outputs = ctx.saved_tensors
        with torch.no_grad():

            # Start from the last module
            for m in list(ctx.rev_block_stack)[::-1]:

                if ctx.preserve_rng_state:
                    # Restore rng state
                    rng_devices = []
                    if ctx.had_cuda_in_fwd:
                        fwd_gpu_devices, fwd_gpu_states = ctx.cuda_status_stack.pop(
                            -1)
                        rng_devices = fwd_gpu_devices

                    fwd_cpu_state = ctx.status_stack.pop(-1)

                    with torch.random.fork_rng(devices=rng_devices,
                                               enabled=ctx.preserve_rng_state):
                        torch.set_rng_state(fwd_cpu_state)
                        if ctx.had_cuda_in_fwd:
                            set_device_states(fwd_gpu_devices, fwd_gpu_states)
                        # Restore input from output
                        inputs = m.inverse(*bak_outputs)
                    # Detach variables from the graph
                    # (works around an issue observed with PyTorch 1.6)
                    inputs = [t.detach().clone() for t in inputs]

                    # requires_grad must be True so gradients are accumulated on the inputs;
                    # those gradients become the grad_output of the next (earlier) block.
                    for inp in inputs:
                        inp.requires_grad = True
                    # run backward for each sub-module
                    with torch.enable_grad():
                        # Restore rng state again
                        with torch.random.fork_rng(
                                devices=rng_devices,
                                enabled=ctx.preserve_rng_state):
                            torch.set_rng_state(fwd_cpu_state)
                            if ctx.had_cuda_in_fwd:
                                set_device_states(fwd_gpu_devices,
                                                  fwd_gpu_states)
                            outputs = m(*inputs)

                        if isinstance(outputs, torch.Tensor):
                            outputs = (outputs, )
                        torch.autograd.backward(outputs, grad_output)

                        grad_output = tuple(
                            inp.grad if isinstance(inp, torch.Tensor) else inp
                            for inp in inputs)
                        bak_outputs = inputs

                else:
                    # Don't save rng state
                    # Restore input from output
                    inputs = m.inverse(*bak_outputs)
                    # Detach variables from the graph
                    # (works around an issue observed with PyTorch 1.6)
                    inputs = [t.detach().clone() for t in inputs]
                    for inp in inputs:
                        inp.requires_grad = True
                    # run backward on each small local graph
                    with torch.enable_grad():
                        outputs = m(*inputs)
                    if isinstance(outputs, torch.Tensor):
                        outputs = (outputs, )
                    torch.autograd.backward(outputs, grad_output)
                    grad_output = tuple(
                        inp.grad if isinstance(inp, torch.Tensor) else inp
                        for inp in inputs)
                    bak_outputs = inputs
        return (None, None) + grad_output
Example No. 9
    def backward_pass(self,
                      y1,
                      y2,
                      dy1,
                      dy2,
                      pos,
                      context,
                      mask_tgt,
                      mask_src,
                      incremental=False,
                      incremental_cache=None,
                      reuse_source=False):
        """
        :param pos:
        :param y1
        :param y2
        :param dy1: dL/dX2
        :param dy2: dL/dY2
        :param context:
        :param mask_tgt:
        :param mask_src:
        :param incremental:
        :param incremental_cache:
        :param reuse_source:
        :return:
        """

        # first block: recompute the ffn transition function
        with torch.enable_grad():
            y1.requires_grad = True

            with torch.random.fork_rng(devices=self.ffn2_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.ffn2_cpu_state)
                set_device_states(self.ffn2_gpu_devices, self.ffn2_gpu_states)

                k_y1 = self.feed_forward_second(y1)

            k_y1.backward(dy2)

        with torch.no_grad():
            z2 = y2 - k_y1
            del k_y1, y2

            # Dz1 = DY1 + Y1.grad
            dz1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # second block
        with torch.enable_grad():
            z2.requires_grad = True
            context.requires_grad = True

            with torch.random.fork_rng(devices=self.src_attn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.src_attn_cpu_state)
                set_device_states(self.src_attn_gpu_devices,
                                  self.src_attn_gpu_states)

                # if not self.ignore_source:
                h_z2, _ = self.src_attention(
                    z2,
                    context,
                    mask_src,
                    incremental=incremental,
                    incremental_cache=incremental_cache)

            h_z2.backward(dz1)

        with torch.no_grad():
            z1 = y1 - h_z2
            del y1, h_z2

            dz2 = dy2 + z2.grad
            z2.grad = None
            del dy2

            grad_context = context.grad
            del context.grad

        # third block
        with torch.enable_grad():
            z1.requires_grad = True

            with torch.random.fork_rng(devices=self.ffn1_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.ffn1_cpu_state)
                set_device_states(self.ffn1_gpu_devices, self.ffn1_gpu_states)

                g_z1 = self.feed_forward_first(z1)

            g_z1.backward(dz2)
        with torch.no_grad():
            x2 = z2 - g_z1
            del z2, g_z1

            dx1 = dz1 + z1.grad

            z1.grad = None
            del dz1

        # fourth block
        with torch.enable_grad():
            x2.requires_grad = True

            with torch.random.fork_rng(devices=self.attn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.attn_cpu_state)
                set_device_states(self.attn_gpu_devices, self.attn_gpu_states)

                f_x2, _ = self.self_attention(
                    x2,
                    pos,
                    key_padding_mask=None,
                    attn_mask=mask_tgt,
                    incremental=incremental,
                    incremental_cache=incremental_cache)

            f_x2.backward(dx1)

        with torch.no_grad():
            x1 = z1 - f_x2
            del z1, f_x2

            dx2 = dz2 + x2.grad
            x2.grad = None

            del dz2

        return x1, x2, dx1, dx2, grad_context
Example No. 10
    def backward(ctx, *output_grads):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        require_grad_indices = list()
        non_grad_indices = list()
        for i in range(len(ctx.input_tensors)):
            temp = ctx.input_tensors[i]
            ctx.input_tensors[i] = temp.detach()
            ctx.input_tensors[i].requires_grad = temp.requires_grad
            if temp.requires_grad:
                require_grad_indices.append(i)
            else:
                non_grad_indices.append(i)
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)

            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                output_tensors = ctx.run_function(*ctx.input_tensors)

        input_tensors_with_grad = list()
        for i in range(len(ctx.input_tensors)):

            if i in require_grad_indices:
                input_tensors_with_grad.append(ctx.input_tensors[i])

        input_grads = torch.autograd.grad(output_tensors, input_tensors_with_grad, output_grads, allow_unused=True)

        return_input_grads = list()
        j = 0

        for i in range(len(ctx.input_tensors)):

            if i in require_grad_indices:
                return_input_grads.append(input_grads[j])
                j = j + 1
            else:
                return_input_grads.append(None)

        return (None, None) + tuple(return_input_grads)
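The allow_unused=True flag in the torch.autograd.grad call above is what lets the recomputed graph return None for inputs it never touched, instead of raising an error; a tiny illustration:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)   # never used below
out = (a * 2).sum()

grad_a, grad_b = torch.autograd.grad(out, [a, b], allow_unused=True)
print(grad_a)   # tensor([2., 2., 2.])
print(grad_b)   # None, because b does not contribute to out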
Example No. 11
    def backward_pass(self,
                      y1,
                      y2,
                      dy1,
                      dy2,
                      context,
                      mask_tgt,
                      mask_src,
                      incremental=False,
                      incremental_cache=None,
                      reuse_source=False):
        """
        :param y1
        :param y2
        :param dy1: dL/dX2
        :param dy2: dL/dY2
        :param context:
        :param mask_tgt:
        :param mask_src:
        :param incremental:
        :param incremental_cache:
        :param reuse_source:
        :return:
        """

        # first block: recompute the ffn transition function
        with torch.enable_grad():
            y1.requires_grad = True

            with torch.random.fork_rng(devices=self.ffn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.ffn_cpu_state)
                set_device_states(self.ffn_gpu_devices, self.ffn_gpu_states)

                g_y1 = self.feed_forward(y1)

            torch.autograd.backward(g_y1, dy2)

        with torch.no_grad():
            # restore X2 = Y2 - G(Y1)
            x2 = y2 - g_y1

            # DX1 = DY1 + Y1.grad
            dx1 = dy1 + y1.grad
            del y2, g_y1, dy1
            y1.grad = None

        # second block
        with torch.enable_grad():

            x2.requires_grad = True
            context.requires_grad = True

            with torch.random.fork_rng(devices=self.attn_gpu_devices,
                                       enabled=True):
                torch.set_rng_state(self.attn_cpu_state)
                set_device_states(self.attn_gpu_devices, self.attn_gpu_states)

                f_x2, coverage, incremental_cache = self.self_attention(
                    x2,
                    mask_tgt,
                    incremental=incremental,
                    incremental_cache=incremental_cache)

                z = f_x2

                # if not self.ignore_source:
                f_x2, _, _ = self.src_attention(
                    f_x2,
                    context,
                    mask_src,
                    incremental=incremental,
                    incremental_cache=incremental_cache)

                f_x2 = f_x2 + z

            torch.autograd.backward(f_x2, dx1)

        with torch.no_grad():
            # restore X1 = Y1 - F(X2)
            x1 = y1 - f_x2
            del y1, f_x2

            dx2 = dy2 + x2.grad
            x2.grad = None
            del dy2
            x2 = x2.detach()
            grad_context = context.grad
            del context.grad

        return x1, x2, dx1, dx2, grad_context
Example No. 12
    def backward(ctx, *grad_outputs):  # pragma: no cover
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible"
            )
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError(
                "Trying to perform backward on the InvertibleCheckpointFunction for more than "
                "{} times! Try raising `num_bwd_passes` by one.".format(
                    ctx.num_bwd_passes))
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # recompute input if necessary
        if not ctx.keep_input:
            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
                rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices,
                                       enabled=ctx.preserve_rng_state):
                if ctx.preserve_rng_state:
                    torch.set_rng_state(ctx.fwd_cpu_state)
                    if ctx.had_cuda_in_fwd:
                        set_device_states(ctx.fwd_gpu_devices,
                                          ctx.fwd_gpu_states)
                # recompute input
                with torch.no_grad():
                    # edge_index and edge_emb
                    inputs_inverted = ctx.fn_inverse(*(outputs + inputs[1:]))
                    # clear memory from outputs
                    # PyTorch 1.0+ way to clear storage
                    for element in outputs:
                        element.storage().resize_(0)

                    if not isinstance(inputs_inverted, tuple):
                        inputs_inverted = (inputs_inverted, )
                    for element_original, element_inverted in zip(
                            inputs, inputs_inverted):
                        element_original.storage().resize_(
                            int(np.prod(element_original.size())))
                        element_original.set_(element_inverted)

        # compute gradients
        with torch.set_grad_enabled(True):
            detached_inputs = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    detached_inputs.append(element.detach())
                else:
                    detached_inputs.append(element)
            detached_inputs = tuple(detached_inputs)
            for det_input, requires_grad in zip(detached_inputs,
                                                ctx.input_requires_grad):
                det_input.requires_grad = requires_grad
            temp_output = ctx.fn(*detached_inputs)
        if not isinstance(temp_output, tuple):
            temp_output = (temp_output, )

        filtered_detached_inputs = tuple(
            filter(lambda x: x.requires_grad, detached_inputs))
        gradients = torch.autograd.grad(outputs=temp_output,
                                        inputs=filtered_detached_inputs +
                                        ctx.weights,
                                        grad_outputs=grad_outputs)

        # Set the gradients manually on the inputs and outputs (mimicking backward())

        input_gradients = []
        i = 0
        for rg in ctx.input_requires_grad:
            if rg:
                input_gradients.append(gradients[i])
                i += 1
            else:
                input_gradients.append(None)

        gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]

        return (None, None, None, None, None, None) + gradients
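The storage().resize_(0) calls above free the output buffers while keeping the Python tensor objects alive; the inverse pass later re-grows the storage and refills it with set_(). A minimal illustration of that pattern (not taken from the example itself):

import torch

t = torch.randn(1000)
t.storage().resize_(0)            # drop the underlying buffer; t keeps its shape metadata
recomputed = torch.arange(1000.)  # stands in for the output of fn_inverse
t.storage().resize_(1000)         # re-allocate the storage (contents are uninitialized)
t.set_(recomputed)                # point t at the recomputed data, mirroring element_original.set_()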
Example No. 13
    def __enter__(self):
        # re-enter the RNG state recorded during the forward pass
        self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices,
                                           enabled=True)
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
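This __enter__ is only half of a context manager; a matching __exit__, written here as an assumption since it is not part of the example, would simply close the stored fork_rng context so that the surrounding RNG state is restored:

    def __exit__(self, exc_type, exc_value, traceback):
        # hypothetical counterpart: leaving the block restores the outer RNG state
        self._fork.__exit__(exc_type, exc_value, traceback)
        self._fork = None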