from typing import Callable

import torch
import torch.autograd.forward_ad as fwAD
from torch.utils._pytree import tree_flatten, tree_unflatten

# Helpers such as generate_cct, generate_subclass_choices_args_kwargs,
# is_tensorlist, and raise_composite_compliance_error are defined elsewhere
# in the same test module.


def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None):
    CCT = generate_cct(enable_recursive_torch_dispatch=True, autograd_view_consistency=False)

    # Permutations of args and kwargs wrapped in CCT.
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        def maybe_tangent(t):
            assert type(t) is not CCT
            # Generate a `tangent` tensor if the given object is a Tensor
            # with requires_grad set.
            if isinstance(t, torch.Tensor) and t.requires_grad:
                return torch.randn_like(t)
            elif is_tensorlist(t):
                return list(torch.randn_like(e) if e.requires_grad else None for e in t)
            return None

        tangent_args = tuple(maybe_tangent(arg) for arg in args)
        flat_kwargs, spec = tree_flatten(kwargs)
        flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
        tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

        # Permutations of tangent args and tangent kwargs wrapped in CCT.
        for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
            new_tang_args, new_tang_kwargs, \
                which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

            with fwAD.dual_level():
                def maybe_make_dual(dual):
                    # Returns a dual tensor if the primal is a tensor (or tensor
                    # subclass) with requires_grad set.
                    primal, tangent = dual
                    if isinstance(primal, torch.Tensor) and primal.requires_grad:
                        return fwAD.make_dual(primal, tangent)
                    elif is_tensorlist(primal):
                        return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri
                                     for pri, tang in zip(primal, tangent))
                    return primal

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        op(*op_args, **op_kwargs)
                    else:
                        gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )
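
# A minimal usage sketch (an assumption, not taken from the original source):
# in practice this checker is driven by a test harness that supplies an op and
# concrete sample inputs; the op and tensors below are illustrative only.
if __name__ == "__main__":
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.randn(3, 3, requires_grad=True)
    # Check that torch.add's forward-AD formula remains composite compliant
    # when its inputs and tangents are wrapped in the CCT subclass.
    check_forward_ad_formula(torch.add, (x, y), {})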
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent differs from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal.
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent
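
# A short follow-up sketch (not part of the snippet above): with a dual input
# in hand, calling the function inside the same ``dual_level`` propagates
# tangents, and ``unpack_dual`` on the output yields the Jacobian-vector
# product (JVP).
with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    # Tensors that carry no tangent behave as constants (zero tangent),
    # so only ``dual_input`` contributes to the output tangent here.
    plain = torch.randn(10, 10)
    dual_output = fn(dual_input, plain)
    # ``unpack_dual`` returns a namedtuple with ``primal`` and ``tangent`` fields.
    jvp = fwAD.unpack_dual(dual_output).tangent
    # For fn(x, y) = x**2 + y**2 with the tangent attached to x only,
    # the JVP is 2 * x * tangent.
    assert torch.allclose(jvp, 2 * primal * tangent)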