Code example #1
from typing import Callable

import torch
import torch.autograd.forward_ad as fwAD
from torch.utils._pytree import tree_flatten, tree_unflatten

# The remaining helpers (generate_cct, generate_subclass_choices_args_kwargs,
# is_tensorlist, raise_composite_compliance_error) are defined alongside this
# function in torch/testing/_internal/composite_compliance.py.

def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None):
    CCT = generate_cct(enable_recursive_torch_dispatch=True, autograd_view_consistency=False)
    # Permutations of args and kwargs wrapped in CCT.
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        def maybe_tangent(t):
            assert type(t) is not CCT
            # Generate a `tangent` tensor
            # if the given object is a Tensor with requires_grad set.
            if isinstance(t, torch.Tensor) and t.requires_grad:
                return torch.randn_like(t)
            elif is_tensorlist(t):
                return list(torch.randn_like(e) if e.requires_grad else None for e in t)
            return None

        tangent_args = tuple(maybe_tangent(arg) for arg in args)
        flat_kwargs, spec = tree_flatten(kwargs)
        flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
        tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

        # Permutations of tangent args and tangent kwargs wrapped in CCT.
        for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
            new_tang_args, new_tang_kwargs, \
                which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

            with fwAD.dual_level():
                def maybe_make_dual(dual):
                    # Returns dual tensor if primal is a tensor/tensor subclass
                    # with requires_grad set.
                    primal, tangent = dual
                    if isinstance(primal, torch.Tensor) and primal.requires_grad:
                        return fwAD.make_dual(primal, tangent)
                    elif is_tensorlist(primal):
                        return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri
                                     for pri, tang in zip(primal, tangent))
                    return primal

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        op(*op_args, **op_kwargs)
                    else:
                        gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )
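
The loop above relies on the permutation helper to decide which inputs get wrapped in the CCT subclass. As a rough, hypothetical sketch of what generate_subclass_choices_args_kwargs might do for the positional arguments (the real helper lives in torch/testing/_internal/composite_compliance.py, also covers kwargs, and may differ in detail), it can be thought of as enumerating every subset of tensor inputs and wrapping the chosen ones:

import itertools
import torch

# Hypothetical sketch, not the actual PyTorch helper: enumerate every way of
# wrapping the tensor-valued positional arguments in a wrapper subclass `cct`
# and record which argument positions were wrapped (mirroring
# `which_args_are_wrapped` above).
def subclass_choices_sketch(args, cct):
    tensor_idxs = [i for i, a in enumerate(args) if isinstance(a, torch.Tensor)]
    for r in range(len(tensor_idxs) + 1):
        for chosen in itertools.combinations(tensor_idxs, r):
            wrapped = tuple(cct(a) if i in chosen else a for i, a in enumerate(args))
            yield wrapped, [i in chosen for i in range(len(args))]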
Code example #2
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent
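
    # Illustrative continuation (not part of the original snippet): running
    # ``fn`` on the dual tensors computes the primal output together with its
    # tangent, i.e. the Jacobian-vector product for the supplied tangents.
    dual_output = fn(dual_input, dual_input_alt)
    jvp = fwAD.unpack_dual(dual_output).tangent
    assert jvp.shape == primal.shape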