Example 1
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *,
             is_out: bool) -> List[Binding]:
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = "{}"
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType("dtype",
                                  OptionalCType(BaseCType(scalarTypeT))),
                name="dtype",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
                name="layout",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
                name="device",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("pin_memory",
                                  OptionalCType(BaseCType(boolT))),
                name="pin_memory",
                default=default,
                argument=a,
            ),
        ]
    else:
        assert_never(a)
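
Every example in this section ends in an assert_never call, which gives exhaustiveness checking: mypy narrows the argument to Never once all Union members are handled, so adding a new member later turns the call into a type error. A minimal self-contained sketch, assuming assert_never from typing_extensions (torchgen ships a similar helper in torchgen.utils):

from typing import Union
from typing_extensions import assert_never

def kind_name(x: Union[int, str]) -> str:
    if isinstance(x, int):
        return "int"
    elif isinstance(x, str):
        return "str"
    else:
        # mypy narrows x to Never here; if a new member is added to the
        # Union, this call stops type-checking until a branch handles it.
        assert_never(x)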
Example 2
    def gen_class_set_output_body(self, k: SchemaKind,
                                  maybe_create_proxy: bool) -> str:
        if self.backend_index.dispatch_key in [
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutograd,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutograd,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable:
            raise AssertionError(
                "SchemaKind.mutable structured operators are currently not supported"
            )
        else:
            assert_never(k)
Example 3
def to_argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Argument]:
    if isinstance(a, Argument):
        return [a]
    elif isinstance(a, SelfArgument):
        return [a.argument]
    elif isinstance(a, TensorOptionsArguments):
        return [a.dtype, a.layout, a.device, a.pin_memory]
    else:
        assert_never(a)
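
For orientation, a hypothetical use of to_argument (the grouped_args name is illustrative, not from torchgen): flattening a schema's grouped arguments back into plain Arguments is a list-of-lists concatenation.

# Hypothetical: grouped_args mixes Argument, SelfArgument and
# TensorOptionsArguments entries; to_argument flattens each group.
flat_args: List[Argument] = [
    arg
    for grouped in grouped_args
    for arg in to_argument(grouped)
]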
Example 4
def gen_class_ctor(self, k: SchemaKind, class_name: str,
                   returns: int) -> str:
    if k is SchemaKind.functional:
        return ""
    elif k is SchemaKind.inplace:
        # TODO: Make sure out argument is guaranteed to be self
        return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
    elif k is SchemaKind.out:
        out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
        out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
        return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
    else:
        assert_never(k)
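
A worked trace of the out branch (class name illustrative, traced by hand from the f-strings above):

# gen_class_ctor(SchemaKind.out, "structured_add_out", returns=2) returns:
#   structured_add_out(Tensor& out0, Tensor& out1) : outputs_{ std::ref(out0), std::ref(out1) } {}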
Example 5
def gen_class_ctor(self, k: SchemaKind, class_name: str,
                   returns: int) -> str:
    if k is SchemaKind.functional:
        return ""
    elif k is SchemaKind.inplace:
        # TODO: Make sure out argument is guaranteed to be self
        return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
    elif k is SchemaKind.out:
        out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
        out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
        return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
    elif k is SchemaKind.mutable or k is SchemaKind.scratch:
        raise AssertionError(
            f"{k} structured operators are currently not supported")
    else:
        assert_never(k)
Example 6
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    if isinstance(a, Argument):
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=None,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    else:
        assert_never(a)
Example 7
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    if isinstance(f, NativeFunctionsGroup):
        g: NativeFunctionsGroup = f
        # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
        # gen_structured() has special logic to handle auto-generated kernels.
        if g.structured:
            return self.gen_structured(g)
        else:
            return list(
                mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
            )
    elif isinstance(f, NativeFunction):
        r = self.gen_unstructured(f)
        return [] if r is None else [r]
    else:
        assert_never(f)
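
The mapMaybe helper used above applies a function over a sequence and drops the None results, Haskell-style. A minimal sketch consistent with that usage (torchgen's actual helper lives in torchgen.utils):

from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")
S = TypeVar("S")

def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    # Yield func(x) for each x in xs, skipping None results.
    for x in xs:
        r = func(x)
        if r is not None:
            yield r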
Example 8
    def gen_class_set_output_body(self, k: SchemaKind) -> str:
        if self.backend_index.dispatch_key in [
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutograd,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutograd,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);"""
        else:
            assert_never(k)
Example 9
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        assert not f.manual_kernel_registration

        if (self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if (self.backend_index.dispatch_key
                == DispatchKey.CompositeExplicitAutograd
                and f.func.kind() is SchemaKind.out):
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif (self.backend_index.dispatch_key is
                  DispatchKey.CompositeExplicitAutograd):
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()))

            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ", ".join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                precomputed_values = [
                    *self.g.out.precomputed.replace.values(),
                    self.g.out.precomputed.add,
                ]
                for precomputed_elems in precomputed_values:
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg,
                                                              binds=arg.name),
                            ))

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            maybe_star = "*" if k is SchemaKind.functional else ""
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
                context.append(
                    Expr(
                        expr=f"{maybe_star}op.outputs_[{i}]",
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(out_arg.nctype.name,
                                        MutRefCType(BaseCType(tensorT))),
                    ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out,
                    method=False,
                    fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ", ".join(e.expr for e in translate(
                    context, out_sig.arguments(), method=False))
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ", ".join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
                else:
                    moved = ", ".join(f"std::move(op.outputs_[{i}]).take()"
                                      for i in range(len(f.func.returns)))
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ", ".join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
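
A worked instance of the return-expression logic above, traced by hand from the code (not actual codegen output):

# k = SchemaKind.functional, len(f.func.returns) == 2:
#   ret_expr = ("std::make_tuple(std::move(op.outputs_[0]).take(), "
#               "std::move(op.outputs_[1]).take())")
# k = SchemaKind.out, out arguments named out0 and out1:
#   ret_expr = "std::forward_as_tuple(out0, out1)"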
Example 10
    def gen_unstructured(
            self,
            f: NativeFunction,
            g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (self.backend_index.dispatch_key == DispatchKey.Meta
                        and f.func.kind() is SchemaKind.inplace and
                        # Defer to composites for meta implementation
                        not f.has_composite_kernel and
                        # Inplace list operations are not supported
                        len(f.func.returns) == 1):
                    inplace_meta = True
                elif (not self.backend_index.use_out_as_primary
                      and g is not None and gets_generated_out_inplace_wrapper(
                          f, g, self.backend_index)):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (self.target is Target.REGISTRATION
                    and not self.selector.is_native_function_selected(f)):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False)

            if self.target is Target.NAMESPACED_DECLARATION:
                result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
                if cpp_sig_group.faithful_signature is not None:
                    result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = generate_defn(cpp_sig_group.signature)
                if cpp_sig_group.faithful_signature is not None:
                    result += generate_defn(cpp_sig_group.faithful_signature)
                return result
            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                args_exprs_str = ", ".join(a.name for a in args)

                device_check = "  // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional)
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name)

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out)
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(
                                self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
                        self_arg = ([
                            f.func.arguments.self_arg.argument
                        ] if f.func.arguments.self_arg is not None else [])
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (f"{a.name}" for a in candidate_args
                             if a.type.is_tensor_like()),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
Example 11
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> List[Binding]:
    def sub_argument(
        a: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return (
                sub_argument(a.dtype)
                + sub_argument(a.layout)
                + sub_argument(a.device)
                + sub_argument(a.pin_memory)
            )
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert "options" not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = "{}"
            elif a.dtype.default == "long":
                default = "at::kLong"  # TODO: this is wrong
            return [
                Binding(
                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                    name="options",
                    default=default,
                    argument=a,
                )
            ]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)
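
To make the faithful flag concrete, a hypothetical call (opts is illustrative, not a real torchgen value): a TensorOptionsArguments group whose constituents all default to None in the schema expands into four bindings when faithful, and collapses into a single options binding with default {} otherwise.

# Hypothetical: opts is a TensorOptionsArguments group whose four
# constituents (dtype, layout, device, pin_memory) all default to None.
faithful_binds = argument(opts, cpp_no_default_args=set(), method=False,
                          faithful=True, has_tensor_options=True)
# -> four Bindings: dtype, layout, device, pin_memory

grouped_binds = argument(opts, cpp_no_default_args=set(), method=False,
                         faithful=False, has_tensor_options=True)
# -> one Binding named "options" of type TensorOptions, default "{}"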