def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
            dispatcher_sig = DispatcherSignature.from_schema(
                g.out.native_function.func)
            name = dispatcher_sig.name()

            dispatcher_order_args = dispatcher.jit_arguments(
                g.out.native_function.func)
            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            print_args_str = ''.join(
                [f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

            func_name = f'AtenXlaTypeDefault::{name}'
            functional_result_name = f'{name}_tmp'
            return_names = cpp.return_names(g.out.native_function)
            if len(return_names) > 1:
                updates = '\n  '.join(
                    f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
                    for i, ret_name in enumerate(return_names))
                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
            else:
                ret_name = return_names[0]
                updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
                returns = ret_name

            functional_sig = DispatcherSignature.from_schema(
                g.functional.native_function.func)

            return f"""\
Example #2
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert (modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None, f.func.arguments.flat_all))

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(
            str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Otherwise, redispatch as normal, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    # call the out-of-place variant of the op
    functional_sig = DispatcherSignature.from_schema(functional_op.func)
    functional_exprs = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
Example #3
def emit_trace_body(f: NativeFunction) -> List[str]:
    trace_body: List[str] = []

    trace_body.append(format_prerecord_trace(f))
    trace_body.append(declare_returned_variables(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)'
    redispatch_args = ', '.join([dispatch_key_set] +
                                [a.expr for a in dispatcher_exprs])

    assign_return_values = f'{tie_return_values(f)} = ' \
                           if f.func.kind() == SchemaKind.functional and f.func.returns else ''

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    trace_body.append(
        TRACE_DISPATCH.substitute(
            assign_return_values=assign_return_values,
            unambiguous_name=f.func.name.unambiguous_name(),
            unpacked_args=redispatch_args,
        ))

    trace_body.append(format_postrecord_trace(f))
    if f.func.returns:
        trace_body.append(f'return {get_return_value(f)};')
    return trace_body
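A toy sketch of the FULL_AFTER masking used above (assumed three-key ordering; the real c10::DispatchKey enum is much larger, but the mechanics are the same):

from enum import IntEnum

class Key(IntEnum):  # toy priority order: higher value dispatches first
    CPU = 1
    Autograd = 2
    Tracer = 3

def full_after(key: Key) -> set:
    # all keys strictly below `key`, mirroring DispatchKeySet(FULL_AFTER, key)
    return {k for k in Key if k < key}

ks = {Key.Tracer, Key.Autograd, Key.CPU}
print(ks & full_after(Key.Tracer))  # Tracer is masked out, so redispatch lands below it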
Example #4
        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments())
            if self.is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(
                    ['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
                dispatcher_call = 'redispatch'
            else:
                dispatcher_exprs_str = ', '.join(a.expr
                                                 for a in dispatcher_exprs)
                dispatcher_call = 'call'

            static_dispatch_block = static_dispatch(
                f, sig, method=False, backend=self.static_dispatch_backend)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
Example #5
        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments(),
                                         method=True)
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            static_dispatch_block = static_dispatch(
                f, sig, method=True, backend=self.static_dispatch_backend)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
Example #6
def unwrap_tensor_args(sig: DispatcherSignature, *,
                       is_view_op: bool) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            unwrapped_name = f'{arg.name}_'
            # For most ops, the functionalization needs to sync any pending updates on the input tensors
            # before calling the operator, since otherwise the operator will act on stale data.
            # For view ops though, we can continue to defer syncing until the tensor is used by
            # a non-view operator.
            maybe_sync_input = '' if is_view_op else f'at::functionalization::impl::sync({arg.name});'
            unwrapped_tensor_args.append(f"""
      {arg.nctype.remove_const_ref().cpp_type()} {unwrapped_name};
      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
        {maybe_sync_input}
        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
      }} else {{
        {unwrapped_name} = {arg.name};
      }}""")
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = '\n      '.join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
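A minimal driver sketch for the helper above, assuming a PyTorch source checkout so the tools.codegen modules (paths assumed from this era of the codebase) are importable:

from tools.codegen.model import FunctionSchema
from tools.codegen.api.types import DispatcherSignature

schema = FunctionSchema.parse(
    "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
sig = DispatcherSignature.from_schema(schema)

# Non-view ops sync pending updates before unwrapping; view ops defer the sync.
unwrap_str, ctx = unwrap_tensor_args(sig, is_view_op=False)
print(unwrap_str)             # one isFunctionalTensor()/sync()/unwrap block per tensor
print([b.name for b in ctx])  # tensor bindings renamed: ['self_', 'other_', 'alpha']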
Example #7
def gen_unstructured_external(f: ExternalBackendFunction) -> List[str]:
    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    if f.metadata is not None:
        # Only generate declarations for operators that xla has defined in the yaml
        return [f"static {dispatcher_sig.decl()};"]
    else:
        return []
Example #8
    def gen_definition(self, f: NativeFunction) -> str:
        unambiguous_name = self.unambiguous_function_name(f)
        args = dispatcher.arguments(f.func)
        sig = DispatcherSignature.from_schema(f.func)

        return deindent(f"""\
            {sig.defn(unambiguous_name)} {{
              return {self.invocation(f)};
            }}\
        """)
Example #9
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if str(f.func.name.name).endswith('_like') or str(
                f.func.name.name).startswith('new_'):
            return None

        name = native.name(f.func)
        native_sig = NativeSignature(f.func)

        if not any(
                isinstance(a.argument, TensorOptionsArguments)
                for a in native_sig.arguments()):
            return None

        native_tensor_args = [
            a for a in native_sig.arguments()
            if isinstance(a.argument, Argument)
            and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently.
            # The first case could probably be improved, though: it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys; there should not be any by the time we reach backend select.
            if native_tensor_args:
                tensor_args = ', '.join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
Example #10
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in exprs])

            return f"""
Example #11
def gen_composite_view_copy_kernel(
        g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ', '.join([
        e.expr
        for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    ])

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
           or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = '''\
  return output.clone();'''
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f'''\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;'''

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
Example #12
def unwrap_tensor_args(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            unwrapped_name = f'{arg.name}_'
            unwrapped_tensor_args.append(
                f'auto {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});')
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = '\n      '.join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
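For contrast with the view-aware variant in Example #6, a sketch (same assumed imports as above) of what this simpler version emits:

sig = DispatcherSignature.from_schema(FunctionSchema.parse(
    "mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"))
unwrap_str, ctx = unwrap_tensor_args(sig)
# One unconditional unwrap per tensor binding, with no sync or isFunctionalTensor guard:
#   auto self_ = at::functionalization::impl::from_functional_tensor(self);
#   auto other_ = at::functionalization::impl::from_functional_tensor(other);
print(unwrap_str)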
Example #13
def emit_inplace_or_view_body(
        fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated InplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
    redispatch_args = ', '.join([dispatch_key_set] +
                                [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()
    inplace_view_body.append(THROW_IF_VARIABLETYPE_ON)
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        for r in cpp.return_names(f):
            inplace_view_body.append(f'increment_version({r});')
    else:
        assert (get_view_info(fn) is not None)
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values='auto ' + TMP_VAR + ' = ',
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
                                           rhs_value=rhs_value))
    if f.func.returns:
        inplace_view_body.append(f'return {get_return_value(f)};')
    return inplace_view_body
Example #14
    def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
        """ Dispatch call via function in a namespace or method on Tensor."""
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher_sig.exprs()

        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # Ops also always have a function variant of the redispatch API.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        dispatch_key_set = 'ks & c10::after_autograd_keyset'
        call = CALL_REDISPATCH.substitute(
            api_name=cpp.name(
                f.func,
                faithful_name_for_out_overloads=True,
            ),
            unpacked_args=[dispatch_key_set] + list(unpacked_args))
        return call
Example #15
    def emit_registration_helper(f: NativeFunction, *, is_view: bool) -> str:
        if is_view and f.has_composite_implicit_autograd_kernel:
            metadata = composite_implicit_autograd_index.get_kernel(f)
            assert metadata is not None
            native_api_name = metadata.kernel
            sig = DispatcherSignature.from_schema(f.func)
            # Note [Composite view ops in the functionalization pass]
            # We don't need to worry about implementing functionalization kernels for views with
            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
            registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})'
        else:
            # non-composite view ops (and inplace ops) get a normal registration.
            registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})'
        return f'm.impl("{f.func.name}", {registration_str});'
Example #16
    def emit_definition_helper(f: NativeFunction) -> Optional[str]:
        if not needs_functionalization(selector, f):
            return None
        if f.is_view_op and f.has_composite_implicit_autograd_kernel:
            # See Note [Composite view ops in the functionalization pass]
            return None
        # Order is important here: ops that are both views and mutations should hit the view path.
        if f.is_view_op:
            # Every view op is expected to have a functional counterpart (e.g. transpose_() -> transpose())
            assert functional_op is not None
            body_str = emit_view_functionalization_body(f, functional_op)
        else:
            # inplace op
            assert modifies_arguments(f)
            body_str = emit_inplace_functionalization_body(f, functional_op)
        sig = DispatcherSignature.from_schema(f.func)
        return f"""
Example #17
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            a_ = arg.name
            unwrapped_name = f'{arg.name}_meta'
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = at::native::empty_strided_meta({a_}.sizes(), {a_}.strides(), \
/*dtype=*/c10::make_optional({a_}.scalar_type()), /*layout=*/c10::make_optional({a_}.layout()), \
/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt);"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = '\n        '.join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
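The generated C++ above mirrors this Python-level idea (a sketch of the intent, not the actual generated kernel):

import torch

def to_meta(t: torch.Tensor) -> torch.Tensor:
    # empty_strided on the meta device carries sizes/strides/dtype but no data,
    # which is all the reference computation needs in order to propagate shapes
    return torch.empty_strided(t.size(), t.stride(), dtype=t.dtype, device="meta")

x = torch.randn(4, 4).transpose(0, 1)
m = to_meta(x)
assert m.is_meta and m.shape == x.shape and m.stride() == x.stride()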
Example #18
    def emit_registration_helper(f: NativeFunction) -> Optional[str]:
        # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile.
        # At some point Vulkan will want to use functionalization, and we'll need to change this.
        if not needs_functionalization(selector, f):
            return None
        if f.is_view_op and f.has_composite_implicit_autograd_kernel:
            metadata = composite_implicit_autograd_index.get_kernel(f)
            assert metadata is not None
            native_api_name = metadata.kernel
            sig = DispatcherSignature.from_schema(f.func)
            # Note [Composite view ops in the functionalization pass]
            # We don't need to worry about implementing functionalization kernels for views with
            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
            registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})'
        else:
            registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})'

        return f'm.impl("{f.func.name}", {registration_str});'
Example #19
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
Example #20
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
Example #21
    def __call__(self, f: NativeFunction) -> Optional[str]:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()
        call_method_name = 'call'
        redispatch_method_name = 'redispatch'

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{str(f.func.name.name)}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})"""

            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
                    dispatcher_call = 'redispatch'
                    method_name = f'{name}::{redispatch_method_name}'
                else:
                    dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
                    dispatcher_call = 'call'
                    method_name = f'{name}::{call_method_name}'

                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            return defns
        else:
            assert_never(self.target)
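A sketch of the overload-name disambiguation described in point (2), assuming the mangling is just base name plus "_" plus overload name (which matches generated identifiers like mul_Tensor):

def unambiguous_name(base: str, overload: str) -> str:
    # "mul" + "Tensor" -> "mul_Tensor"; ops without an overload keep their base name
    return f"{base}_{overload}" if overload else base

assert unambiguous_name("mul", "Tensor") == "mul_Tensor"    # at::_ops::mul_Tensor::call
assert unambiguous_name("is_complex", "") == "is_complex"   # at::_ops::is_complex::call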
Example #22
    def create_decl(f: NativeFunction) -> str:
        with native_function_manager(f):
            return DispatcherSignature.from_schema(f.func).decl()
Example #23
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert (modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None, f.func.arguments.flat_all))

    # Note [functionalizing copy_() and not preserving strides]
    # copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
    # We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
    # This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
    # Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
    # This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
    # This seems like a reasonable tradeoff, for a few reasons:
    # - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
    # - There are actually a few other places where the functionalization pass currently doesn't support strides:
    #   calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
    if str(f.func.name) == 'copy_':
        exprs = [keyset] + [a.name for a in unwrapped_args_ctx]
        functional_call_str = f"""\
            auto tmp_intermediate = at::_ops::to_other::redispatch({keyset}, src_, self_, non_blocking, false, c10::nullopt);
            tmp_output = at::_ops::expand_as::redispatch({keyset}, tmp_intermediate, self_);"""
    elif functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(
            str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Otherwise, redispatch as normal, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]
        functional_call_str = \
            f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::redispatch({', '.join(functional_exprs)});"

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
Example #24
    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
        inplace_meta = False
        if self.dispatch_key not in f.dispatch:
            if (self.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    DispatchKey.CompositeImplicitAutograd not in f.dispatch
                    and DispatchKey.CompositeExplicitAutograd not in f.dispatch
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        sig = NativeSignature(f.func, prefix='wrapper_')

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_guard = "// DeviceGuard omitted"  # default

            if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is a precedence order for which argument we use
                    # for the device guard; the chain below encodes it.
                    self_arg = [
                        f.func.arguments.self_arg.argument
                    ] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg, f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (f'{a.name}'
                         for a in candidate_args if a.type.is_tensor_like()),
                        None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
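The device-guard precedence above (self argument, then out= arguments, then remaining positionals) in a runnable toy form, with assumed stand-in records instead of the real codegen types:

from dataclasses import dataclass
from typing import List, Optional
import itertools

@dataclass
class Arg:
    name: str
    is_tensor_like: bool

def pick_device_arg(self_arg: List[Arg], out_args: List[Arg],
                    positional: List[Arg]) -> Optional[str]:
    candidates = itertools.chain(self_arg, out_args, positional)
    return next((a.name for a in candidates if a.is_tensor_like), None)

# out= wins over a later positional tensor when there is no self argument:
assert pick_device_arg([], [Arg("out", True)], [Arg("mat", True)]) == "out"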
Example #25
        def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
            if not requires_backend_wrapper(f, self.backend_index):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                ]
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                ]
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            # See Note [External Backends Follow Dispatcher API]
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                # This codegen is only responsible for registering CPU fallback kernels
                # We also skip registrations if there is a functional backend kernel,
                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
                if self.backend_index.get_kernel(f) is not None or \
                        (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                    return ''
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f'  m.impl("{f.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(f.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = to_cpu({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f, method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                # TODO: uncomment the resize line below. Taken out temporarily for testing
                update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

            returns = ''
            if f.func.returns:
                ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
Example #26
        def gen_unstructured_external(
                f: ExternalBackendFunction) -> Optional[str]:
            if not requires_backend_wrapper(f):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                ]
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                ]
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            dispatcher_sig = DispatcherSignature.from_schema(
                f.native_function.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                if f.metadata is not None:
                    # xla has their own kernel: register it
                    namespace = 'AtenXlaType'
                else:
                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
                    namespace = 'AtenXlaTypeDefault'
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
                return f'  m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
            # TODO: we should generate out wrappers for ALL valid out kernels, not just ones in xla's hardcoded list
            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
                    and isinstance(g, ExternalBackendFunctionsGroup):
                return gen_out_wrapper(g)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(
                f.native_function.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.native_function.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f.native_function,
                method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.native_function.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

            returns = ''
            if f.native_function.func.returns:
                ret_names = cpp.return_names(f.native_function,
                                             fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.native_function.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.native_function.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.native_function.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
Example #27
    def gen_declaration(self, f: NativeFunction) -> str:
        unambiguous_name = self.unambiguous_function_name(f)
        sig = DispatcherSignature.from_schema(f.func)
        return f"TORCH_API {sig.decl(unambiguous_name)};"
Example #28
def emit_view_functionalization_body(f: NativeFunction,
                                     functional_op: NativeFunction) -> str:
    # view op case
    assert f.is_view_op

    if f.tag is Tag.inplace_view:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have a NativeFunctionsGroup, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert f.func.kind() is SchemaKind.inplace
        # Requirement: Every inplace_view op needs to have a corresponding functional view op, which we paired together beforehand.
        assert functional_op is not None
        api_name = functional_op.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(functional_op.func)
    else:
        api_name = f.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(f.func)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    assert_view_op_properties(f.func)
    view_tensor_name = dispatcher_sig.arguments()[0].name

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)
    view_redispatch_args = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, call_sig.arguments(), method=False)
    ]

    forward_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=False)
    reverse_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=True)

    # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(
        dispatcher_sig)
    meta_call_args = [
        e.expr
        for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
    ]

    if f.tag is Tag.inplace_view:
        # See Note [Functionalization Pass - Inplace View Ops] for more details
        return f"""
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          return {forward_lambda.inner_call()}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      {unwrap_tensor_args_str}
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{api_name}::call({', '.join(meta_call_args)});
      }}
      // See Note [Propagating strides in the functionalization pass]
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
"""

    else:
        return f"""