Code example #1
    def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = utils.get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs
        else:
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict[
                "packed_non_tensor_outputs"] = packed_non_tensor_outputs
            return outputs
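How a caller drives this forward is not shown in the excerpt; below is a minimal hedged sketch of such a wrapper. The class name CheckpointFunction, the pack_kwargs helper that flattens keyword arguments into (kwarg_keys, flat_args), and the unpack_non_tensors helper are assumptions; only the argument order used in the excerpt is taken from it.

    # Hedged sketch of a caller for the forward above; `CheckpointFunction`,
    # `pack_kwargs`, and `unpack_non_tensors` are assumed names.
    def checkpointed_call(run_function, *args, **kwargs):
        kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)  # flatten kwargs into positional args
        parent_ctx_dict = {}  # receives non-tensor outputs by reference
        outputs = CheckpointFunction.apply(
            run_function, parent_ctx_dict, kwarg_keys, *flat_args)
        packed = parent_ctx_dict.get("packed_non_tensor_outputs")
        if packed is not None:
            # re-merge the tensor outputs with the non-tensor ones returned by reference
            outputs = unpack_non_tensors(outputs, packed)
        return outputs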
Code example #2
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(
                    *args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)

        # Lie to torch that we have no None items, to avoid the assert:
        # replace each None output with an empty placeholder tensor.
        result = []
        for o in outputs:
            if o is None:
                o = torch.zeros(0)
            result.append(o)

        return tuple(result)
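Replacing None with empty tensors in the forward implies the matching backward has to undo it; a minimal hedged sketch of that reverse mapping (not taken from the source):

    # Hedged sketch: map the zero-element placeholder tensors back to None so
    # the recomputed outputs line up with what run_function actually returned.
    def restore_none_outputs(outputs):
        return tuple(
            None if (isinstance(o, torch.Tensor) and o.numel() == 0) else o
            for o in outputs)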
Code example #3
File: random.py  Project: fairseq/Megatron-LM
 def forward(ctx, run_function, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.fwd_rng_state = utils.get_rng_state()
     ctx.save_for_backward(*args)
     with torch.no_grad():
         outputs = run_function(*args)
     return outputs
Code example #4
    def forward(  # type: ignore
            ctx: Any, dummy_tensor_requires_grad: torch.Tensor,
            run_function: Any, parent_ctx_dict: Dict[str, Any],
            kwarg_keys: Tuple[str, ...], *args: Any, **kwargs: Any) -> Any:
        torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad
                                          for x in tensor_inputs)
            tensor_inputs = tuple(
                x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad(), enable_checkpointing():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)
            the_module = unpacked_args[0]

        # Because we run with torch.no_grad(), we can't actually access
        # outputs.requires_grad. Instead, we manually compute it by
        # checking if either the input or the module needs grads
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, then the
        # parameters would have been deleted. If so, we need to access
        # the views into the flattened parameters.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(param.requires_grad
                                   for param in parameters) or any(
                                       x.requires_grad for x in tensor_inputs)
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict[
                "packed_non_tensor_outputs"] = packed_non_tensor_outputs

        return outputs
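The offload branch only pays off if the backward moves the saved activations back to their original devices before recomputation; a hedged sketch of that restore step, using only the ctx attributes stored above:

    # Hedged sketch of the restore step a matching backward would perform.
    def restore_offloaded_inputs(ctx):
        tensor_inputs = ctx.saved_tensors
        if ctx.fwd_device is not None:  # inputs were offloaded to CPU in the forward
            tensor_inputs = tuple(
                t.to(ctx.fwd_device[i], non_blocking=True)
                for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad  # re-mark leaves that need grads
        return tensor_inputs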
Code example #5
File: rev_utils.py  Project: xuanyuzhou98/higher
    def forward(ctx, rev_block_stack, preserve_rng_state, *inputs):
        '''
        :param ctx:                 context object (plays the role of self)
        :param rev_block_stack:     module with multiple reversible blocks stacked, such as RevSequential
        :param preserve_rng_state:  whether to save the RNG state so the random ops can be reproduced;
                                    only the torch RNG state is covered, numpy is not included
        :param inputs:              multiple tensors; at least one must have requires_grad=True,
                                    otherwise this function will not compute backpropagation,
                                    which is a limitation of torch.autograd.Function
        :return:
        '''
        # Warn if none of the input tensors has requires_grad=True
        check_backward_validity(inputs)
        # Make sure the input is a ModuleList that supports the inverse operation
        assert isinstance(rev_block_stack, nn.ModuleList)
        assert hasattr(rev_block_stack, 'inverse') and callable(
            rev_block_stack.inverse)

        ctx.rev_block_stack = rev_block_stack
        ctx.preserve_rng_state = preserve_rng_state

        # Save the RNG state
        # Note that the state has to be saved and restored layer by layer
        if preserve_rng_state:
            ctx.status_stack = []
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.cuda_status_stack = []

            # Since the modules are executed in reverse order during the backward pass,
            # each sub-module needs to save its RNG state separately.
            outputs = inputs
            for m in ctx.rev_block_stack:
                fwd_cpu_state = torch.get_rng_state()
                ctx.status_stack.append(fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    fwd_gpu_devices, fwd_gpu_states = get_device_states(
                        *outputs)
                    ctx.cuda_status_stack.append(
                        [fwd_gpu_devices, fwd_gpu_states])
                # Run under torch.no_grad because we don't save intermediate activations
                with torch.no_grad():
                    outputs = m(*outputs)

        else:
            # If the RNG state does not need to be saved, run the whole stack directly to get the output
            with torch.no_grad():
                outputs = rev_block_stack(*inputs)

        # Save output for backward function
        ctx.save_for_backward(*outputs)
        return outputs
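The per-layer RNG bookkeeping above only makes sense together with a backward that walks the stack in reverse, reconstructing each block's input via inverse() and replaying that layer's RNG state before recomputing it with grad enabled. A condensed, hedged sketch of that idea (not the project's actual backward; CUDA RNG restoration is omitted):

    # Hedged sketch of the reverse pass matching the forward above.
    def reverse_backward(ctx, *grad_outputs):
        outputs = ctx.saved_tensors
        for i in reversed(range(len(ctx.rev_block_stack))):
            block = ctx.rev_block_stack[i]
            with torch.no_grad():
                inputs = block.inverse(*outputs)          # recover this block's inputs
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.status_stack[i])  # replay the layer's CPU RNG state
            detached = tuple(x.detach().requires_grad_() for x in inputs)
            with torch.enable_grad():
                recomputed = block(*detached)             # rebuild the local graph
            torch.autograd.backward(recomputed, grad_outputs)
            grad_outputs = tuple(x.grad for x in detached)
            outputs = inputs
        return (None, None) + grad_outputs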
Code example #6
File: checkpoint.py  Project: pytorch/xla
  def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    if preserve_rng_state:
      ctx.fwd_cpu_state = torch.get_rng_state()
      # Don't eagerly initialize the cuda context by accident.
      # (If the user intends that the context is initialized later, within their
      # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
      # we have no way to anticipate this will happen before we run the function.)
      ctx.had_cuda_in_fwd = False
      if torch.cuda._initialized:
        ctx.had_cuda_in_fwd = True
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    tensor_outputs = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
      else:
        ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
      outputs = run_function(*args)

    return outputs
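This placeholder bookkeeping pairs with a backward like the one in torch.utils.checkpoint; below is a condensed, hedged sketch of that counterpart (autocast restoration is omitted, and set_device_states is assumed to come from torch.utils.checkpoint, next to the get_device_states used above):

    # Condensed, hedged sketch of the matching backward.
    def backward_sketch(ctx, *grad_outputs):
        inputs = list(ctx.inputs)
        for idx, tensor in zip(ctx.tensor_indices, ctx.saved_tensors):
            inputs[idx] = tensor  # fill the None placeholders left in the forward
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached = tuple(x.detach().requires_grad_(x.requires_grad)
                             if torch.is_tensor(x) else x for x in inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        # only backpropagate through outputs that actually require grad
        out_with_grad = [o for o in outputs if torch.is_tensor(o) and o.requires_grad]
        grads_for_out = [g for o, g in zip(outputs, grad_outputs)
                         if torch.is_tensor(o) and o.requires_grad]
        torch.autograd.backward(out_with_grad, grads_for_out)
        grads = tuple(x.grad if torch.is_tensor(x) else None for x in detached)
        return (None, None) + grads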
Code example #7
File: random.py  Project: myleott/Megatron-LM
 def forward(ctx, run_function, preserve_rng_state, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.preserve_rng_state = preserve_rng_state
     ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
     if preserve_rng_state:
         ctx.fwd_cpu_state = torch.get_rng_state()
         # Don't eagerly initialize the cuda context by accident.
         # (If the user intends that the context is initialized later, within their
         # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
         # we have no way to anticipate this will happen before we run the function.)
         ctx.had_cuda_in_fwd = False
         if torch.cuda._initialized:
             ctx.had_cuda_in_fwd = True
             ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
     ctx.save_for_backward(*args)
     with torch.no_grad():
         outputs = run_function(*args)
     return outputs
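What distinguishes this variant is the extra ctx.fwd_cuda_rng_state_tracker line: before recomputation, the backward can restore Megatron's model-parallel CUDA RNG tracker alongside the default RNG states. A hedged sketch of just that restore step; the set_device_states helper is an assumption, and only get_cuda_rng_tracker() is confirmed by the excerpt:

    # Hedged sketch of the RNG restore a matching backward would perform
    # before recomputing run_function.
    def restore_rng_for_recompute(ctx):
        if ctx.preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_state)                          # default CPU RNG
            if ctx.had_cuda_in_fwd:
                set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)  # default CUDA RNG
        # the model-parallel tracker state was stashed unconditionally in the forward
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)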
Code example #8
    def forward(  # type: ignore
        ctx: Any,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)
            the_module = unpacked_args[0]
            inc_counter(the_module)

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs
Code example #9
File: checkpointing.py  Project: lupoglaz/OpenFold2
    def forward(ctx, run_function, callback, preserve_rng_state, *args):
        print('Fwd', args)
        chkpt.check_backward_validity(args)
        ctx.callback = callback
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = chkpt.get_device_states(
                    *args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs
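Across all of these variants the user-facing entry point is usually a thin wrapper around .apply(). A minimal hedged sketch for the signature used in this last example; the class name CheckpointFunctionWithCallback is an assumption, only the argument order (run_function, callback, preserve_rng_state, *args) comes from the excerpt:

    # Hedged sketch of a user-facing wrapper for the forward above.
    def checkpoint_with_callback(run_function, callback, *args, preserve_rng_state=True):
        return CheckpointFunctionWithCallback.apply(
            run_function, callback, preserve_rng_state, *args)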