Example #1
def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    assert dispatch_strategy(fn) == 'use_derived'
    f = fn.func
    info = fn.info
    fw_derivatives = fn.fw_derivatives

    name = cpp.name(f.func)
    inplace = f.func.kind() == SchemaKind.inplace
    is_out_fn = f.func.kind() == SchemaKind.out
    returns_void = len(f.func.returns) == 0
    base_name = get_base_name(f)
    view_info = get_view_info(fn)

    def gen_differentiable_input(
        arg: Union[Argument, SelfArgument, TensorOptionsArguments]
    ) -> Optional[DifferentiableInput]:
        if isinstance(arg, TensorOptionsArguments):
            return None
        a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg

        # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
        # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
        # not handled properly as they are irrelevant for this codegen.
        cpp_type = cpp.argument_type(a, binds=a.name).cpp_type()

        if not is_differentiable(a.name, a.type, info):
            return None
        return DifferentiableInput(
            name=a.name,
            type=a.type,
            cpp_type=cpp_type,
        )

    @with_native_function
    def gen_differentiable_inputs(
            f: NativeFunction) -> List[DifferentiableInput]:
        return list(
            mapMaybe(gen_differentiable_input, f.func.arguments.non_out))

    def find_args_with_derivatives(
        differentiable_inputs: List[DifferentiableInput]
    ) -> List[DifferentiableInput]:
        """Find arguments that have derivative definitions"""
        if info is None or not info.has_derivatives:
            return differentiable_inputs
        names = set(name for d in info.derivatives for name in d.var_names)
        differentiable = [
            arg for arg in differentiable_inputs if arg.name in names
        ]
        if len(differentiable) != len(names):
            missing = names - set(arg.name for arg in differentiable)
            raise RuntimeError(
                f'Missing arguments for derivatives: {missing} in {info.name}')
        return differentiable

    differentiable_inputs = gen_differentiable_inputs(f)
    args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
    differentiable_outputs = gen_differentiable_outputs(fn)

    undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
        name in DONT_REQUIRE_DERIVATIVE)

    requires_derivative = (not undifferentiable) and (
        len(differentiable_inputs) > 0) and (len(differentiable_outputs) > 0)

    requires_fw_derivatives = not undifferentiable and len(fw_derivatives) > 0

    if info is not None and info.has_derivatives and not requires_derivative:
        raise RuntimeError(
            f'ERROR: derivative ignored for {name} -- specified an autograd function without derivative'
        )

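    # Emits statements that stash the inputs referenced by the backward formula
    # onto the freshly created grad_fn, adding should_compute_output(...) guards
    # around individual saves where that is cheap to determine.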
    def emit_save_inputs() -> List[str]:
        setup: List[str] = []
        if info is None or not info.has_derivatives:
            return setup

        has_tensorlist_arg = any(
            is_tensor_list_type(arg.type) for arg in args_with_derivatives)

        # We don't want to save tensors if we know that they will never be used
        # when computing the derivative, so we add guards to those statements
        def guard_for(arg: SavedAttribute) -> Optional[str]:
            assert info is not None

            # It's hard to determine the edge offset if we have TensorLists
            if has_tensorlist_arg:
                return None

            # Empirical evaluation of the cases where we insert those guards in
            # backward shows that they are somewhat useless. E.g. there's no need
            # to guard on some values captured from forward, because they had to
            # require_grad if the backward function even gets executed. I don't
            # have any good ideas for detecting those cases, so I simply disabled the
            # checks.
            if 'backward' in info.name:
                return None

            # If there's a single derivative we could compute, we already have
            # a requires_grad check that is sufficient
            if len(args_with_derivatives) <= 1:
                return None

            # We really only care about trimming down the number of tensors we save
            if arg.nctype.type != BaseCType(tensorT):
                return None

            # We want to emit simple guards, so we only allow that if checking one
            # input is enough to determine whether we need that value
            used_in = [d for d in info.derivatives if arg in d.saved_inputs]
            assert len(used_in) > 0
            if len(used_in) != 1:
                return None
            derivative = used_in[0]
            if len(derivative.var_names) != 1:
                return None
            derivative_var_name = derivative.var_names[0]

            # Figure out the offset of the edge that uses this variable
            for edge_off, a in enumerate(args_with_derivatives):
                if a.name == derivative_var_name:
                    break
            else:
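                # for-else: reached only if the loop finished without breaking,
                # i.e. the derivative's variable is not among args_with_derivatives.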
                raise AssertionError()

            return f'grad_fn->should_compute_output({edge_off})'

        setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
        for arg in args_with_derivatives:
            if is_tensor_list_type(arg.type):
                setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')

        return setup

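    # Builds the autograd bookkeeping that precedes the redispatch: for out=
    # functions it only declares an (unused) grad_fn and checks that no input or
    # output requires grad; otherwise it constructs grad_fn, wires it to the
    # inputs that have derivatives, and saves the inputs needed by backward.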
    def setup_derivative(
            differentiable_inputs: List[DifferentiableInput]) -> List[str]:
        body: List[str] = []
        if is_out_fn:
            # For out functions, ensure that no input or output requires grad
            body.append(DECLARE_GRAD_FN.substitute(op='Node'))
            body.append(
                SETUP_NONE_REQUIRES_GRAD.substitute(
                    base_name=base_name,
                    args_to_check=[arg.name for arg in differentiable_inputs]))
            body.append(
                SETUP_NONE_REQUIRES_GRAD.substitute(
                    base_name=base_name,
                    args_to_check=[arg.name
                                   for arg in differentiable_outputs]))
            return body

        op = info.op if info is not None and info.has_derivatives else 'NotImplemented'
        setup = []
        setup.extend(
            ASSIGN_GRAD_FN.substitute(
                op=op,
                op_ctor='' if info is not None and info.has_derivatives else
                f'"{cpp.name(f.func)}"',
                args_with_derivatives=[
                    arg.name for arg in args_with_derivatives
                ],
            ).split('\n'))
        setup.extend(emit_save_inputs())

        body.extend(
            emit_check_no_requires_grad(differentiable_inputs,
                                        args_with_derivatives))
        body.append(DECLARE_GRAD_FN.substitute(op=op))
        body.append(SETUP_DERIVATIVE.substitute(setup=setup))
        return body

    def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
        body: List[str] = []
        if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
            return body
        for arg in differentiable_outputs:
            name = arg.name
            # TODO: should be `arg.type.is_tensor_like()`?
            if arg.cpp_type in [
                    'at::Tensor', 'at::TensorList',
                    'const c10::List<c10::optional<at::Tensor>> &'
            ]:
                body.append(
                    f'throw_error_for_complex_autograd({name}, "{base_name}");'
                )
        return body

    def emit_check_no_requires_grad(
        tensor_args: List[DifferentiableInput],
        args_with_derivatives: List[DifferentiableInput],
    ) -> List[str]:
        """Checks that arguments without derivatives don't require grad"""
        body: List[str] = []
        for arg in tensor_args:
            if arg in args_with_derivatives:
                continue
            name = arg.name
            if info and name in info.non_differentiable_arg_names:
                continue
            if name == 'output':
                # Double-backwards definitions sometimes take in 'input' and
                # 'output', but only define the derivative for input.
                continue
            body.append(f'check_no_requires_grad({name}, "{name}");')
        return body

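    # Turns each SavedAttribute into a `grad_fn->..._ = ...;` assignment, wrapping
    # Tensors (and optional Tensors) in SavedVariable and copying list/array/string
    # arguments by value so they outlive the forward call.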
    def save_variables(
        saved_variables: Sequence[SavedAttribute],
        is_output: bool,
        guard_for: Callable[[SavedAttribute],
                            Optional[str]] = lambda name: None,
    ) -> Sequence[str]:
        # assign the saved variables to the generated grad_fn
        stmts: List[str] = []
        for arg in saved_variables:
            name = arg.nctype.name.name if isinstance(
                arg.nctype.name, SpecialArgName) else arg.nctype.name
            type = arg.nctype.type
            expr = arg.expr
            if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
                    type == MutRefCType(OptionalCType(BaseCType(tensorT))) or (is_output and type == BaseCType(scalarT)):
                var = name
                name += '_'
                if var == 'self' and inplace:
                    var = 'self.clone()'
                    assert not is_output
                if inplace and is_output:
                    var = 'self'
                    is_inplace_view = f'{var}.is_view()'
                    expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
                else:
                    expr = f'SavedVariable({var}, {str(is_output).lower()})'
            elif type == BaseCType(tensorListT) or type == ListCType(
                    OptionalCType(BaseCType(tensorT))):
                expr = f'make_saved_variable_list({name})'
                name += '_'
            elif type == BaseCType(intArrayRefT):
                expr = expr + ".vec()"
            elif type == BaseCType(stringT):
                expr = f'std::string({expr})'
            elif type == OptionalCType(BaseCType(stringT)):
                expr = f'{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt'
            guard = guard_for(arg)
            if guard is None:
                stmts.append(f'grad_fn->{name} = {expr};')
            else:
                stmts.append(f'if ({guard}) {{')
                stmts.append(f'  grad_fn->{name} = {expr};')
                stmts.append('}')
        return stmts

    # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
    #  - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
    #  - redispatch() avoids a redundant call to RecordFunction, which was already called right before
    #    we entered this autograd kernel.
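    # For a hypothetical in-place op, the emitted C++ looks roughly like
    #   at::redispatch::mul_(ks & c10::after_autograd_keyset, self_, other_);
    # (a sketch only; the exact text comes from the CALL_REDISPATCH template and
    # the op's unpacked argument names).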
    def emit_dispatch_call(f: NativeFunction, input_base: str,
                           unpacked_args: Sequence[str]) -> str:
        """ Dispatch call via function in a namespace or method on Tensor."""
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher_sig.exprs()

        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # Ops also always have a function variant of the redispatch API.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        dispatch_key_set = 'ks & c10::after_autograd_keyset'
        call = CALL_REDISPATCH.substitute(api_name=cpp.name(
            f.func,
            faithful_name_for_out_overloads=True,
        ),
                                          unpacked_args=[dispatch_key_set] +
                                          list(unpacked_args))
        return call

    def wrap_output(f: NativeFunction, unpacked_bindings: List[Binding],
                    var: str) -> str:
        call = ''
        rhs_value: Optional[str] = None
        if not any(r.type.is_tensor_like() for r in f.func.returns):
            rhs_value = var
        else:
            rhs_value = f'std::move({var})'
        assert rhs_value is not None
        call += ASSIGN_RETURN_VALUE.substitute(
            return_values=tie_return_values(f), rhs_value=rhs_value)
        return call

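    # Wraps `call` with debug-only statements that record the TensorImpl and
    # Storage pointers of the tensor arguments before the redispatch and verify
    # afterwards that the redispatched kernel did not swap them out.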
    def enforce_same_tensorimpl_and_storage(
            call: str, unpacked_bindings: List[Binding]) -> str:
        save_ptrs_stmts: List[str] = []
        enforce_same_ptrs_stmts: List[str] = []
        if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            for unpacked_binding in unpacked_bindings:
                arg = unpacked_binding.name
                noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref(
                )
                if noref_cpp_type == BaseCType(tensorListT):
                    save_ptrs_stmts += [
                        SAVE_TENSORLIST_STORAGE.substitute(
                            tensorlist_name=arg),
                        SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                    ]
                    enforce_same_ptrs_stmts += [
                        ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                            tensorlist_name=arg),
                        ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                            tensorlist_name=arg)
                    ]
                elif noref_cpp_type == ListCType(
                        OptionalCType(BaseCType(tensorT))):
                    save_ptrs_stmts += [
                        SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                            tensorlist_name=arg),
                        SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                            tensorlist_name=arg)
                    ]
                    enforce_same_ptrs_stmts += [
                        ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                            tensorlist_name=arg),
                        ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                            tensorlist_name=arg)
                    ]
                elif noref_cpp_type == BaseCType(tensorT):
                    save_ptrs_stmts += [
                        SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                        SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                    ]
                    enforce_same_ptrs_stmts += [
                        ENFORCE_SAME_TENSOR_STORAGE.substitute(
                            tensor_name=arg),
                        ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                    ]
        assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (
            not save_ptrs_stmts and not enforce_same_ptrs_stmts)
        if save_ptrs_stmts and enforce_same_ptrs_stmts:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
        return call

    def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
        # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
        # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
        # the baseType operations still dispatch to non-Variable type, even if the arguments passed
        # in are now Variables.
        # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
        unpacked_args = [b.name for b in unpacked_bindings]
        base_type_call = emit_dispatch_call(f, 'self_', unpacked_args)

        if get_view_info(fn) is not None or modifies_arguments(f):
            guard = 'at::AutoDispatchBelowAutograd guard;'
        else:
            guard = 'at::AutoDispatchBelowADInplaceOrView guard;'

        if not modifies_arguments(f) and not returns_void:
            call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
                base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard)

            call += wrap_output(f, unpacked_bindings, TMP_VAR)
        else:
            call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
                base_type_call=base_type_call, guard=guard)
        call = enforce_same_tensorimpl_and_storage(call, unpacked_bindings)
        return call

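    # Attaches grad_fn to the outputs. Ops that modify their arguments (and are
    # not views) go through rebase_history, everything else through set_history;
    # the generated C++ looks roughly like
    #   if (grad_fn) { set_history(flatten_tensor_args( result ), grad_fn); }
    # (a sketch; the exact text comes from the SET_HISTORY template).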
    def emit_history() -> str:
        fn = 'rebase' if modifies_arguments(f) and view_info is None else 'set'
        output_names = [r.name for r in differentiable_outputs]
        # TODO: flatten allocates a std::vector, which could be expensive
        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
            outs=output_names)
        return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)

    def emit_save_outputs() -> str:
        if is_out_fn:
            # out functions don't currently support differentiation
            return ''
        if info is not None and info.has_derivatives:
            stmts = save_variables(info.all_saved_outputs, True)
            if len(stmts) == 0:
                return ''
            return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)
        return ''

    def emit_any_requires_grad() -> List[str]:
        return [
            SETUP_ANY_REQUIRES_GRAD.substitute(args_with_derivatives=[
                arg.name for arg in args_with_derivatives
            ]),
        ]

    def emit_check_inplace() -> List[str]:
        if not inplace:
            return []
        return [
            f'check_inplace({arg.name}, _any_requires_grad);'
            for arg in differentiable_outputs
        ]

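    # Forward-mode AD: for every declared forward derivative, emit a block that
    # fires when at least one used input carries a tangent (forward grad),
    # evaluates the formula on the unpacked tangents/primals, and installs the
    # result as the output's forward grad.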
    def emit_fw_derivatives() -> List[str]:
        content: List[str] = []
        for derivative in fw_derivatives:
            res = derivative.var_name
            if f.func.name.name.inplace:
                # TODO update this when inplace namings are unified
                res = "self"

            assert derivative.required_inputs_fw_grad is not None
            requires_fw_grad = " || ".join([
                FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
                for inp in differentiable_inputs
                if inp.name in derivative.required_inputs_fw_grad
            ])
            if not requires_fw_grad:
                # Handle functions like stack
                # For these, we don't unpack anything and always call the user function
                if not (len(differentiable_inputs) == 1 and
                        is_tensor_list_type(differentiable_inputs[0].type)):
                    raise RuntimeError(
                        f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
                        'forward AD formula does not use any input tangent) even though a forward gradient '
                        'formula has been defined for it. This case should only happen for functions that '
                        'take a single TensorList as input. All other cases are not supported right now.'
                    )
                requires_fw_grad = "true"
            unpacked_arguments = ""
            for inp in differentiable_inputs:
                if inp.name in derivative.required_inputs_fw_grad:
                    unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
                        inp=inp.name)
                if inp.name in (derivative.required_inputs_primal or []):
                    unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
                        inp=inp.name)

            if inplace:
                is_inplace_str = "true"
            else:
                is_inplace_str = "false"

            if isinstance(derivative.var_type,
                          BaseType) and derivative.var_type.is_tensor_like():
                fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR.substitute(
                    out_arg=res, is_inplace=is_inplace_str)
            elif isinstance(derivative.var_type,
                            ListType) and derivative.var_type.is_tensor_like():
                fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
                    out_arg=res, is_inplace=is_inplace_str)
            else:
                raise RuntimeError(
                    "Unsupported output type for forward derivative")
            # View ops create a fw_grad that is already a view of the base's fw_grad, so just use that
            content.append(
                FW_DERIVATIVE_TEMPLATE.substitute(
                    requires_fw_grad=requires_fw_grad,
                    formula=derivative.formula,
                    out_arg=res,
                    unpacked_arguments=unpacked_arguments,
                    fw_grad_setter=fw_grad_setter))
        return content

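    # When no forward AD formula is available, emit checks that raise if any
    # Tensor (or element of a TensorList) input comes in with a tangent set.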
    def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
        def get_msg() -> str:
            if is_out_fn:
                msg = name + " (because it is an out= function)"
            else:
                msg = name
            return msg

        res = ""
        to_check: List[str] = []
        for inp in differentiable_inputs:
            if is_tensor_type(inp.type):
                to_check.append(
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name))
            elif is_tensor_list_type(inp.type):
                cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
                res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
                    arg=inp.name, cond=cond, msg=get_msg())
            else:
                raise RuntimeError(
                    f'Unsupported input type for "{name}" when forbidding forward AD usage.'
                )

        if len(to_check) > 0:
            cond = " || ".join(to_check)
            res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond,
                                                            msg=get_msg())
        return res

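    # Assemble the kernel body: unpack arguments, set up grad_fn, redispatch to
    # the underlying kernel, record history, handle forward AD, and finally save
    # the outputs needed by backward.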
    body: List[str] = []
    unpack_args_stats, unpacked_bindings = unpack_args(f)

    body.extend(unpack_args_stats)
    if requires_derivative:
        body.extend(emit_any_requires_grad())
        body.extend(emit_check_inplace())
        body.extend(setup_derivative(differentiable_inputs))
    body.append(declare_returned_variables(f))

    body.append(emit_call(f, unpacked_bindings))
    if requires_derivative:
        # set_flags has to appear after version_counter, because rebase_history
        # requires that the counter is incremented before it is called
        body.append(emit_history())
        body.extend(emit_check_if_in_complex_autograd_allowlist())

    if is_out_fn:
        # out= functions do not support forward-mode AD, so always forbid it here
        body.append(emit_forbid_fw_derivatives(is_out_fn=True))
    else:
        if requires_fw_derivatives:
            body.extend(emit_fw_derivatives())
        else:
            body.append(emit_forbid_fw_derivatives())

    if requires_derivative:
        # Save only after the forward AD has been set up
        body.append(emit_save_outputs())

    if base_name in RESET_GRAD_ACCUMULATOR:
        # `inplace` implies that there is exactly one output named `self`,
        # so we can keep the generated code easy. If you need to
        # `reset_grad_accumulator` in an operator that's not `inplace`, you can
        # remove this assert but the code generation will get more elaborate
        assert inplace
        body.append('reset_grad_accumulator(self);')
    if not returns_void:
        body.append(f'return {get_return_value(f)};')
    return body


# Only ops that are not implemented manually (see MANUAL_AUTOGRAD) and whose
# dispatch strategy is 'use_derived' get a kernel generated by emit_body above.
def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    f = fn.func
    name = cpp.name(f.func)
    return name not in MANUAL_AUTOGRAD and dispatch_strategy(
        fn) == 'use_derived'