def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
    if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
        checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = utils.get_rng_state()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)

    if isinstance(outputs, torch.Tensor):
        return outputs
    else:
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs
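# The variants in this collection that take *kwarg_keys* lean on two helpers,
# unpack_kwargs and split_non_tensors, whose definitions are not shown here. Below is
# a minimal sketch of the behavior the callers appear to assume; it is inferred from
# usage, not copied from the library that defines them.
import torch


def unpack_kwargs(kwarg_keys, flat_args):
    # The trailing len(kwarg_keys) entries of flat_args are keyword values;
    # everything before them stays positional.
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs


def split_non_tensors(mixed):
    # Separate tensors (which can go through ctx.save_for_backward) from everything
    # else, remembering positions so the original sequence can be rebuilt later.
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors = []
    packed = {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed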
def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
    ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*args)
    # return outputs
    #
    # Lie to torch that there are no None items, to avoid the assert
    #
    result = []
    for o in outputs:
        if o is None:
            o = torch.zeros(0)
        result.append(o)
    return tuple(result)
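# The placeholder loop above exists because autograd Functions require every forward
# output to be a tensor; a None output trips an internal assert. Whoever consumes the
# outputs (the matching backward, or the wrapper that calls .apply) therefore has to
# map the zero-element placeholders back to None. A hedged sketch of that unpacking
# step, not taken from the original source, and assuming zero-element tensors are
# never legitimate outputs:
def _restore_nones(outputs):
    # Undo the "lie": zero-element placeholder tensors stand in for None outputs.
    return tuple(None if torch.is_tensor(o) and o.numel() == 0 else o for o in outputs)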
def forward(ctx, run_function, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.fwd_rng_state = utils.get_rng_state()
    ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*args)
    return outputs
def forward(  # type: ignore
    ctx: Any,
    dummy_tensor_requires_grad: torch.Tensor,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad(), enable_checkpointing():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]

        # Because we run with torch.no_grad(), we can't actually access
        # outputs.requires_grad. Instead, we manually compute it by
        # checking if either the input or the module needs grads
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, then the
        # parameters would have been deleted. If so, we need to access
        # the views into the flattened parameters.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(param.requires_grad for param in parameters) or any(
            x.requires_grad for x in tensor_inputs
        )
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

    if not isinstance(outputs, torch.Tensor):
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
    return outputs
def forward(ctx, rev_block_stack, preserve_rng_state, *inputs):
    '''
    :param ctx: context object, analogous to self
    :param rev_block_stack: module with multiple reversible blocks stacked, such as RevSequential
    :param preserve_rng_state: whether to save the random number state so the random numbers can
                               be reproduced; only torch RNG state is covered, numpy is not
    :param inputs: multiple tensors; at least one must have requires_grad=True, otherwise this
                   function will not compute backpropagation (a limitation of torch.autograd.Function)
    :return:
    '''
    # Warn if none of the input tensors has requires_grad=True
    check_backward_validity(inputs)
    # Make sure the input is a ModuleList and supports the inverse operation
    assert isinstance(rev_block_stack, nn.ModuleList)
    assert hasattr(rev_block_stack, 'inverse') and callable(rev_block_stack.inverse)

    ctx.rev_block_stack = rev_block_stack
    ctx.preserve_rng_state = preserve_rng_state

    # RNG state saving.
    # Note that the state has to be saved and restored layer by layer.
    if preserve_rng_state:
        ctx.status_stack = []
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.cuda_status_stack = []

        # Since the modules are executed in reverse order during the backward pass,
        # each sub-module needs to save its random number state separately.
        outputs = inputs
        for m in ctx.rev_block_stack:
            fwd_cpu_state = torch.get_rng_state()
            ctx.status_stack.append(fwd_cpu_state)
            if ctx.had_cuda_in_fwd:
                fwd_gpu_devices, fwd_gpu_states = get_device_states(*outputs)
                ctx.cuda_status_stack.append([fwd_gpu_devices, fwd_gpu_states])
            # Run under torch.no_grad so that intermediate variables are not saved
            with torch.no_grad():
                outputs = m(*outputs)
    else:
        # If the random number state does not need to be saved, the whole stack can be
        # run directly to get the outputs
        with torch.no_grad():
            outputs = rev_block_stack(*inputs)

    # Save the outputs for the backward function
    ctx.save_for_backward(*outputs)
    return outputs
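# Usage sketch for the reversible-stack variant above. The forward only requires an
# nn.ModuleList subclass that is callable and exposes a callable inverse(); the class
# and function names below (RevBlock, RevSequential, RevCheckpointFunction) are
# hypothetical, chosen for illustration.
import torch
from torch import nn


class RevBlock(nn.Module):
    # Toy additive-coupling block: (x1, x2) -> (x1, x2 + f(x1)), which is exactly invertible.
    def __init__(self, dim):
        super().__init__()
        self.f = nn.Linear(dim, dim)

    def forward(self, x1, x2):
        return x1, x2 + self.f(x1)

    def inverse(self, y1, y2):
        return y1, y2 - self.f(y1)


class RevSequential(nn.ModuleList):
    # Satisfies the asserts above: an nn.ModuleList that is callable and has inverse().
    def forward(self, *xs):
        for m in self:
            xs = m(*xs)
        return xs

    def inverse(self, *ys):
        for m in reversed(self):
            ys = m.inverse(*ys)
        return ys


stack = RevSequential([RevBlock(8), RevBlock(8)])
x1 = torch.randn(4, 8, requires_grad=True)
x2 = torch.randn(4, 8, requires_grad=True)
# The snippet above would be the forward() of a hypothetical Function, e.g.:
# y1, y2 = RevCheckpointFunction.apply(stack, True, x1, x2)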
def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.)
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    for i, arg in enumerate(args):
        if torch.is_tensor(arg):
            tensor_inputs.append(arg)
            ctx.tensor_indices.append(i)
            ctx.inputs.append(None)
        else:
            ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
        outputs = run_function(*args)
    return outputs
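# The ctx.inputs / ctx.tensor_indices bookkeeping in this variant (and the similar one
# further down) only pays off in the matching backward, which stitches the saved
# tensors back into their original argument positions before re-running run_function
# with grad enabled. Below is a simplified sketch of that backward, mirroring the
# usual torch.utils.checkpoint pattern but omitting RNG and autocast restoration; it
# belongs in the same autograd.Function class as the forward above and is not a
# verbatim copy of any library's implementation.
def backward(ctx, *grad_outputs):
    # Reassemble the original argument list: non-tensors were kept in ctx.inputs,
    # saved tensors go back into their recorded slots.
    inputs = list(ctx.inputs)
    for idx, tensor in zip(ctx.tensor_indices, ctx.saved_tensors):
        inputs[idx] = tensor

    # Detach so the recomputation builds a fresh graph rooted at these inputs.
    detached = [
        inp.detach().requires_grad_(inp.requires_grad) if torch.is_tensor(inp) else inp
        for inp in inputs
    ]

    with torch.enable_grad():
        outputs = ctx.run_function(*detached)
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)

    # Backprop through the recomputed graph only for outputs that require grad.
    outputs_with_grad, grads = [], []
    for out, grad in zip(outputs, grad_outputs):
        if torch.is_tensor(out) and out.requires_grad:
            outputs_with_grad.append(out)
            grads.append(grad)
    torch.autograd.backward(outputs_with_grad, grads)

    input_grads = tuple(
        inp.grad if torch.is_tensor(inp) else None for inp in detached
    )
    # One None per non-tensor forward argument (run_function, preserve_rng_state).
    return (None, None) + input_grads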
def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.)
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
    ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*args)
    return outputs
def forward(  # type: ignore
    ctx: Any,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
        torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]
        inc_counter(the_module)

    if not isinstance(outputs, torch.Tensor):
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
    return outputs
def forward(ctx, run_function, callback, preserve_rng_state, *args):
    print('Fwd', args)
    chkpt.check_backward_validity(args)
    ctx.callback = callback
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.)
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = chkpt.get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    for i, arg in enumerate(args):
        if torch.is_tensor(arg):
            tensor_inputs.append(arg)
            ctx.tensor_indices.append(i)
            ctx.inputs.append(None)
        else:
            ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
        outputs = run_function(*args)
    return outputs