Example #1
def canonical_function(functions: Sequence[NativeFunction],
                       name: str) -> NativeFunction:
    for f in functions:
        if cpp.name(f.func) == name:
            return f
    # some functions only have in-place variants
    assert name + "_" == cpp.name(functions[0].func)
    return functions[0]
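
The fallback at the end of this helper relies on ATen's naming convention: in-place variants append a trailing underscore to the base name, so a derivatives.yaml entry written against the out-of-place name can resolve to the in-place overload when that is the only variant. A minimal standalone sketch of that convention, using hypothetical op names instead of real NativeFunction objects:

def resolve_canonical_sketch(candidate_names, requested):
    # same shape as canonical_function above, but over plain strings
    for n in candidate_names:
        if n == requested:
            return n
    # some functions only have in-place variants
    assert requested + "_" == candidate_names[0]
    return candidate_names[0]

assert resolve_canonical_sketch(["relu_"], "relu") == "relu_"
assert resolve_canonical_sketch(["add", "add_"], "add") == "add"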
Example #2
def format_prerecord_trace(f: NativeFunction) -> str:
    if not should_trace(f):
        return ""

    # TODO: clean up old codegen behavior
    is_inplace = (
        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
        and not f.func.name.name.dunder_method
    )
    add_args = (
        RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else ""
    )
    additional_inputs = (
        SELECT.substitute(
            cond="tracer_state->force_outplace",
            true=add_args,
            false="",
        )
        if add_args
        else ""
    )

    return PRE_RECORD_TRACE.substitute(
        set_op_name=format_trace_op_name(f),
        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
        inplace_guard=INPLACE_GUARD.substitute(
            name=cpp.name(f.func),
            mutable_input=f.func.arguments.out[0].name
            if f.func.arguments.out
            else "self",
        )
        if is_inplace
        else "",
    )
Example #3
def is_factory_function(f: NativeFunction) -> bool:
    if Variant.function not in f.variants:
        return False

    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    return has_tensor_options or name.endswith("_like")
Example #4
def process_function(f: NativeFunction) -> Optional[str]:
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if Variant.function not in f.variants or not is_factory:
        return None

    sig = CppSignatureGroup.from_native_function(f, method=False).signature
    formals: List[str] = []
    exprs: List[str] = []
    requires_grad = "false"
    for arg in sig.arguments():
        qualified_type = fully_qualified_type(arg.type)
        if arg.default:
            formals.append(f"{qualified_type} {arg.name} = {arg.default}")
        else:
            formals.append(f"{qualified_type} {arg.name}")

        if isinstance(arg.argument, TensorOptionsArguments):
            # note: we remove the requires_grad setting from the TensorOptions because
            # it is ignored anyways (and we actually have an assertion that it isn't set
            # which would fail otherwise). We handle requires_grad explicitly here
            # instead of passing it through to the kernel.
            exprs.append(
                f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)")
            # Manually set the requires_grad bit on the result tensor.
            requires_grad = f"{arg.name}.requires_grad()"
        else:
            exprs.append(arg.name)

    return f"""\
Example #5
def gen_trace_type(out: str, native_functions: List[NativeFunction],
                   template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded(
        "TraceType.cpp",
        [
            fn for fn in native_functions
            if cpp.name(fn.func) not in MANUAL_TRACER
        ],
        key_fn=lambda fn: fn.root_name,
        base_env={
            "generated_comment":
            f"@generated from {template_path}/TraceType.cpp",
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
Example #6
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    typenames: Dict[str, str] = {
    }  # map from unique name + field name lists to typedef name
    typedefs: List[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(f"""\
static PyTypeObject* {typename} = get_namedtuple("{name}");""")

    return typedefs, typenames
Example #7
def method_registration(f: NativeFunction) -> str:
    assert cpp.name(f.func) not in MANUAL_TRACER

    return WRAPPER_REGISTRATION.substitute(
        name=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type="TraceType",
    )
Example #8
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
    input_base = "input_base"
    replay_view_func = ""
    updated_unpacked_args: List[str] = []
    known_view_arg_simple_types: List[CType] = [
        BaseCType(longT),
        OptionalCType(BaseCType(longT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
    ]
    for unpacked_binding in unpacked_bindings:
        arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
        if arg == "self_":
            updated_unpacked_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )

        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
            symIntArrayRefT
        ):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(longT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base, replay_view_call=replay_view_call
    )

    is_view_with_metadata_change = (
        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
    )

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
    )
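
ARRAYREF_TO_VEC and OPTIONAL_TO_VAL above are CodeTemplate constants defined elsewhere in this file. A rough, self-contained sketch of the materialization step, using string.Template as a stand-in; the template text here is an assumption for illustration, not the real constants:

from string import Template

# assumed approximations of the real CodeTemplate constants
ARRAYREF_TO_VEC_SKETCH = Template("auto ${vec} = ${arg}.vec();\n")
OPTIONAL_TO_VAL_SKETCH = Template("auto ${val} = ${arg}.value_or(${default});\n")

replay_view_func = ""
# an IntArrayRef argument is copied into a vector so the lambda can capture it by value
replay_view_func += ARRAYREF_TO_VEC_SKETCH.substitute(arg="size", vec="size_vec")
# an optional int64_t is materialized to a concrete value
replay_view_func += OPTIONAL_TO_VAL_SKETCH.substitute(
    arg="storage_offset", val="storage_offset_val", default="0")
assert replay_view_func == (
    "auto size_vec = size.vec();\n"
    "auto storage_offset_val = storage_offset.value_or(0);\n")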
Example #9
def canonical_function(functions: Sequence[NativeFunction],
                       name: str) -> NativeFunction:
    for f in functions:
        if (not f.func.is_functional_fn() and not f.func.is_out_fn()
                and name == str(f.func.name.name)):
            return f
    # some functions only have in-place variants
    assert name + "_" == cpp.name(functions[0].func)
    return functions[0]
Example #10
def generate_call_to_view_ops(g: NativeFunctionsViewGroup,
                              backend_index: BackendIndex) -> str:
    schema = g.view.func
    kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    if kernel:
        kernel_name = kernel.kernel
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
Example #11
    def go(f: NativeFunction) -> str:
        # header comments
        if isinstance(ps, PythonSignatureDeprecated):
            schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}"
        else:
            schema_comment = f"// aten::{f.func}"

        deprecated = "[deprecated] " if ps.deprecated else ""

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            map(lambda a: f"{a.type_str} {a.name}",
                dispatch_lambda_args(ps, f)))
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f) or (ps.method and
                                          ("requires_grad" in parser_outputs)))
        set_requires_grad = (
            f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            if need_set_requires_grad else "")

        if lambda_return == "void":
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = f"{typename}, " if typename is not None else ""
            return f"""\
Example #12
def should_generate_py_binding(f: NativeFunction) -> bool:
    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
            return False

    signature = str(f.func)
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True
Example #13
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = [out_arg.name for out_arg in schema.arguments.out]
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    cpp_func_name = cpp.name(schema)
    cpp_arg_names = ",".join(arg_names)
    return f"at::cpu::{cpp_func_name}({cpp_arg_names})"
Example #14
def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str:
    if f.func.name.overload_name:
        name = f"{cpp.name(f.func)}_{f.func.name.overload_name}"
    else:
        name = cpp.name(f.func)

    # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key.
    # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated,
    # the key argument should not be passed.
    # We do not append key if it is Default so that generated functions from
    # before per-dispatch-key derivatives were added retain the same names.
    if key != "Default":
        name = name + f"_{key}"
    return name
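
The resulting name is the cpp name, suffixed with the overload name when one exists, and suffixed with the dispatch key only when it is not "Default". A plain-string sketch of the convention with hypothetical inputs:

def wrapper_name_sketch(cpp_name, overload_name, key="Default"):
    name = f"{cpp_name}_{overload_name}" if overload_name else cpp_name
    if key != "Default":
        name = f"{name}_{key}"
    return name

assert wrapper_name_sketch("add", "Tensor") == "add_Tensor"
assert wrapper_name_sketch("add", "Tensor", key="AutogradNestedTensor") == "add_Tensor_AutogradNestedTensor"
assert wrapper_name_sketch("relu_", "") == "relu_"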
Example #15
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and relevant entry for the map in same file.
    """
    typenames: Dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    definitions: List[str] = []  # function defintion to register the typedef
    map_entries: List[
        str
    ] = []  # C++ map entry of <function_name, function creates it namedtuple>

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(
                f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
            )
            map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
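
A hypothetical walkthrough of the pieces interpolated into the emitted C++ for an op named "max" returning fields (values, indices); the op and field names are illustrative only:

fieldnames = ["values", "indices"]
name = "max"
fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
assert fields == '{"values", ""}, {"indices", ""}'
definitions = []  # still empty, so the first typedef gets no numeric suffix
typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
assert typename == "maxNamedTuple"
map_entry = f'{{"{name}", get_{name}_namedtuple()}}, '
assert map_entry == '{"max", get_max_namedtuple()}, '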
Example #16
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    # Parse and load derivatives.yaml
    differentiability_infos, used_dispatch_keys = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path,
        tags_path)

    template_path = os.path.join(autograd_dir, "templates")

    native_funcs = parse_native_yaml(native_functions_path,
                                     tags_path).native_functions
    fns = list(
        sorted(
            filter(operator_selector.is_native_function_selected_for_training,
                   native_funcs),
            key=lambda f: cpp.name(f.func),
        ))
    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(
            fns, differentiability_infos)

    # Generate VariableType.h/cpp
    if not disable_autograd:
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
            used_dispatch_keys,
        )

        gen_inplace_or_view_type(out, native_functions_path, tags_path,
                                 fns_with_diff_infos, template_path)

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)
    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path,
                           template_path)
Example #17
def should_generate_py_binding(f: NativeFunction) -> bool:
    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
    if "generated" in f.tags:
        return False
    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
            return False

    signature = str(f.func)
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True
Example #18
def method_definition(f: NativeFunction) -> str:
    assert cpp.name(f.func) not in MANUAL_TRACER

    formals = ", ".join(
        # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        ["c10::DispatchKeySet ks"] + [
            f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
            for a in f.func.schema_order_arguments()
        ])

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
Example #19
def process_function(f: NativeFunction) -> Optional[str]:
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if Variant.function not in f.variants or not is_factory:
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)
    r = ""
    for sig in sigs:
        formals: List[str] = []
        exprs: List[str] = []
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            if arg.default:
                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
            else:
                formals.append(f"{qualified_type} {arg.name}")

            if isinstance(arg.argument, TensorOptionsArguments):
                # note: we remove the requires_grad setting from the TensorOptions because
                # it is ignored anyways (and we actually have an assertion that it isn't set
                # which would fail otherwise). We handle requires_grad explicitly here
                # instead of passing it through to the kernel.
                exprs.append(
                    f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
                )
                # Manually set the requires_grad bit on the result tensor.
                requires_grad = f"{arg.name}.requires_grad()"
            else:
                exprs.append(arg.name)

        r += f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
    return r
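
A hypothetical walkthrough of the argument handling for a factory signature along the lines of ones(IntArrayRef size, TensorOptions options = {}); the names and types are illustrative. The TensorOptions argument is forwarded with requires_grad stripped, and requires_grad is re-applied on the resulting variable instead:

exprs = []
requires_grad = "false"
for arg_name, is_tensor_options in [("size", False), ("options", True)]:
    if is_tensor_options:
        # strip requires_grad from the options forwarded to the kernel ...
        exprs.append(f"at::TensorOptions({arg_name}).requires_grad(c10::nullopt)")
        # ... and re-apply it on the produced variable
        requires_grad = f"{arg_name}.requires_grad()"
    else:
        exprs.append(arg_name)

assert exprs == ["size", "at::TensorOptions(options).requires_grad(c10::nullopt)"]
assert requires_grad == "options.requires_grad()"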
Example #20
def format_trace_op_name(f: NativeFunction) -> str:
    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    if (f.func.kind() in (SchemaKind.functional, SchemaKind.out)
            or f.func.name.name.dunder_method):
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        trace_name = str(f.func.name.name)
        trace_name = RENAME_TRACE.get(trace_name, trace_name)
        return OP_NAME.substitute(trace_name=trace_name)

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    outplace_trace_name = f.func.name.name.base
    inplace_trace_name = cpp.name(f.func)
    outplace_trace_name = RENAME_TRACE.get(outplace_trace_name,
                                           outplace_trace_name)
    inplace_trace_name = RENAME_TRACE.get(inplace_trace_name,
                                          inplace_trace_name)

    return SELECT.substitute(
        cond="tracer_state->force_outplace",
        true=OP_NAME.substitute(trace_name=outplace_trace_name),
        false=OP_NAME.substitute(trace_name=inplace_trace_name),
    )
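
For the in-place branch, both an out-of-place and an in-place trace name are emitted and selected at runtime on tracer_state->force_outplace, with RENAME_TRACE rewriting a few legacy names. A small illustration with a toy dict standing in for the real RENAME_TRACE mapping defined elsewhere in this file:

RENAME_TRACE_SKETCH = {"zero": "zeros_like", "fill": "full_like"}  # assumed contents

base, inplace = "zero", "zero_"
outplace_trace_name = RENAME_TRACE_SKETCH.get(base, base)       # renamed to "zeros_like"
inplace_trace_name = RENAME_TRACE_SKETCH.get(inplace, inplace)   # "zero_" is left unchanged
assert (outplace_trace_name, inplace_trace_name) == ("zeros_like", "zero_")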
Example #21
def generate_out_variant_call(g: NativeFunctionsGroup,
                              backend_index: BackendIndex) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = []
    kernel_name = get_out_kernel_name(g, backend_index)
    if g.structured:
        # structured op starts with the output tensor argument.
        arg_names = [out_arg.name for out_arg in schema.arguments.out]
    else:
        arg_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_func_name = cpp.name(schema)
    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
Example #22
def name(func: FunctionSchema) -> str:
    return cpp.name(func)
Example #23
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    name = cpp.name(f.func)
    fieldnames = namedtuple_fieldnames(f.func.returns)
    return "_".join([name] + fieldnames)
Example #24
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName",
                                                  "BackendMetadata"]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        gets_composite_kernel = True
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - a "functional" overload name.
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        func = f.func.signature(keep_return_names=True).with_name(
            f.func.name.remove_inplace().with_overload(
                "functional" if not f.func.name.overload_name else
                f"{f.func.name.overload_name}_functional"))
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        gets_composite_kernel = False
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    if gets_composite_kernel:
        backend_metadata = {
            DispatchKey.CompositeExplicitAutograd: {
                func.name: BackendMetadata(cpp.name(func), structured=False)
            }
        }
    else:
        backend_metadata = {}

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user friendly- don't generate methods.
            variants=set([Variant.function]),
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_explicit_autograd_kernel=gets_composite_kernel,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=set(["generated"]),
        ),
        backend_metadata,
    )
Example #25
def type_wrapper_name(f: NativeFunction) -> str:
    if f.func.name.overload_name:
        return f"{cpp.name(f.func)}_{f.func.name.overload_name}"
    else:
        return cpp.name(f.func)
Example #26
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
        #   See Note [Overload Ambiguity With Functional Variants]
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        func = f.func.signature(keep_return_names=True).with_name(
            OperatorName(
                name=BaseOperatorName(
                    base=f.func.name.name.base,
                    inplace=False,
                    dunder_method=f.func.name.name.dunder_method,
                    # See Note [Overload Ambiguity With Functional Variants]
                    functional_overload=f.func.kind() == SchemaKind.mutable,
                ),
                overload_name=f.func.name.overload_name,
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.functional:
            func = functional_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable or functional variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
    # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
    # `randn.generator_with_names_out`.
    kernel_name = (
        func.name.unambiguous_name()
        if func.kind() == SchemaKind.out
        else cpp.name(func)
    )
    backend_metadata = {
        DispatchKey.CompositeExplicitAutograd: {
            func.name: BackendMetadata(
                kernel=kernel_name,
                structured=False,
                cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
            )
        }
    }

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user friendly- don't generate methods.
            variants=set([Variant.function]),
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_explicit_autograd_kernel=True,
            has_composite_explicit_autograd_non_functional_kernel=False,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=set(["generated"]) | (f.tags & {"nondeterministic_seeded"}),
            namespace=f.namespace,
        ),
        backend_metadata,
    )
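
The out= kernel naming convention described in the comment above can be sketched with plain strings; the composed names are an assumption about what OperatorName.unambiguous_name() yields for the two randn overloads mentioned there, not a call into torchgen:

def unambiguous_name_sketch(name, overload_name):
    return f"{name}_{overload_name}" if overload_name else name

assert unambiguous_name_sketch("randn", "names_out") == "randn_names_out"
assert unambiguous_name_sketch("randn", "generator_with_names_out") == "randn_generator_with_names_out"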
Example #27
def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml"""
    def canonical_function(functions: Sequence[NativeFunction],
                           name: str) -> NativeFunction:
        for f in functions:
            if (not f.func.is_functional_fn() and not f.func.is_out_fn()
                    and name == str(f.func.name.name)):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ["foo", "bar"]."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str,
                         derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula))
            num_grads_uses += len(
                re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'")

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'.  If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration.")

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients.")

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[Sequence[Derivative], Sequence[ForwardDerivative],
               Sequence[Binding], Sequence[str], Sequence[str], ]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [r.name for r in f.func.returns]  # only used for the assert below
        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(
            f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(
                    create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(f, formula, names,
                                                   available_named_gradients)
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(
            non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}")

        # Next, let us determine the list of inputs in order.
        # TODO: do we need eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        [isinstance(diff, str) for diff in output_differentiability]):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification},"
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. ")
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(k for k, v in functions_by_schema.items()
                          if cpp.name(v.func) == defn_name)
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f".  Available signatures:\n{avail}")

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k) for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name)
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}")

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml.")

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs."
            "Please use a different name in native_functions.yaml.")

    (
        derivatives,
        forward_derivatives,
        args_with_derivatives,
        non_differentiable_arg_names,
        available_named_gradients,
    ) = set_up_derivatives(canonical)

    used_named_gradients: Set[str] = set()
    for d in derivatives:
        used_named_gradients |= d.named_gradients

    # only assign an op name if we are actually going to calculate a derivative
    op = None
    if args_with_derivatives:
        op_prefix = _create_op_prefix(defn_name)
        op = f"{op_prefix}{op_counter[op_prefix]}"
        op_counter[op_prefix] += 1

    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=dedup_vars(
            [v for d in derivatives for v in d.saved_inputs]),
        all_saved_outputs=dedup_vars(
            [v for d in derivatives for v in d.saved_outputs]),
        available_named_gradients=available_named_gradients,
        used_named_gradients=used_named_gradients,
        args_with_derivatives=args_with_derivatives,
        non_differentiable_arg_names=non_differentiable_arg_names,
        output_differentiability=output_differentiability,
        output_differentiability_conditions=output_differentiability_conditions,
    )
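
The grads-usage sanity check in check_grad_usage counts identifier occurrences of "grads" and compares them against the indexed uses it extracts. A standalone demo with an assumed identifier pattern and index extractor standing in for IDENT_REGEX and used_gradient_indices, which are defined elsewhere:

import re

IDENT_REGEX_SKETCH = r"(^|\W){}($|\W)"  # assumed shape of IDENT_REGEX

def used_gradient_indices_sketch(formula):
    return [int(i) for i in re.findall(r"grads\[(\d+)\]", formula)]

formula = "grads[0] * alpha + grads[1]"
num_grads_uses = len(re.findall(IDENT_REGEX_SKETCH.format("grads"), formula))
used = used_gradient_indices_sketch(formula)
assert num_grads_uses == 2 and used == [0, 1]
# every use of grads is indexed and more than one index appears, so no error is raised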
Example #28
def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    f = fn.func
    name = cpp.name(f.func)
    return name not in MANUAL_AUTOGRAD and dispatch_strategy(
        fn) == "use_derived"
Example #29
def wrapper_name(func: FunctionSchema) -> str:
    if func.name.overload_name:
        return f"{cpp.name(func)}_{func.name.overload_name}"
    else:
        return cpp.name(func)
Example #30
def get_out_kernel_name(g: NativeFunctionsGroup,
                        backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel