예제 #1
0
def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join(
        [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
    )

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor
    ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
예제 #2
0
def gen_composite_view_copy_kernel(
        g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None

    # For view_copy.SymInt overloads,
    # See gen_symint_view_copy_kernel.
    if g.view_copy.func.name.overload_name == "SymInt":
        return None

    # We can make view_copy work in more cases by using reshape()
    # when a normal view call would ordinarily fail.
    # This also makes LTC more efficient, because they don't need to include
    # clone() calls in their graph (which is normally needed by reshape).
    if str(g.view_copy.func.name) == "view_copy":
        return """\
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
  DimVector shape = infer_size_dv(size, self.numel());
  if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
    return self.reshape(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone();
  }
}
"""
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)

    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join([
        e.expr
        for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    ])

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor) or g.view.func.returns[0].type == ListType(
            BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
    # And we always write the kernel for a generated op in terms of a non-generated op.
    if g.inplace is not None and "generated" not in g.inplace.tags:
        target_f = g.inplace
    elif g.mutable is not None and "generated" not in g.mutable.tags:
        target_f = g.mutable
    else:
        # We should be guaranteed to have a valid inplace/mutable variant to call into.
        # See Note: [Mutable Ops Not Using Functionalization]
        raise AssertionError(str(g.functional.func))

    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

    context: List[Union[Binding, Expr]] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
    # We need to check for which inputs to the mutating operator are mutable,
    # and clone those inputs first.
    for a_curr, a_tgt in zip(
        dispatcher.jit_arguments(g.functional.func),
        dispatcher.jit_arguments(target_f.func),
    ):
        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
            clone_mutable_inputs.append(
                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
            )
            context.append(
                Expr(
                    expr=f"{a_curr.name}_clone",
                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
                )
            )
            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
            cloned_return_names.append(f"{a_curr.name}_clone")
        else:
            context.append(dispatcher.argument(a_curr))
    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])

    out_name = "output"
    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
    ret_str = return_str(
        g.functional.func.returns, inner_return_names + cloned_return_names
    )

    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
    return f"""
예제 #4
0
def gen_symint_view_copy_kernel(view_copy: NativeFunction,
                                view_copy_symint: NativeFunction) -> str:
    # view_copy.symint is a native signature, since we're generating an at::native:: kernel
    view_copy_symint_sig = NativeSignature(view_copy_symint.func)

    # view_copy is a dispatcher signature, since we're calling into the at::_ops API
    view_copy_sig = DispatcherSignature(view_copy.func)

    exprs = ", ".join([
        e.expr for e in translate(view_copy_symint_sig.arguments(),
                                  view_copy_sig.arguments())
    ])

    return f"""
예제 #5
0
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all,
                                              cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(
        schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
예제 #6
0
def unwrap_tensor_args(
    sig: DispatcherSignature, *, is_view_op: bool
) -> Tuple[str, List[Binding]]:
    """Generate C++ that unwraps every tensor-like argument of ``sig``.

    For each tensor-like binding a local named ``<arg>_`` is declared and
    filled either from the functional wrapper or via the conversion function
    returned by ``get_owning_type``.

    Args:
        sig: dispatcher signature whose arguments are inspected.
        is_view_op: when true, no ``sync`` call is emitted for the inputs.

    Returns:
        A pair ``(code, context)``: the joined C++ snippets, and the bindings
        to use for the redispatch call (tensor-like bindings renamed to their
        unwrapped local, all other bindings passed through untouched).
    """
    redispatch_ctx: List[Binding] = []
    snippets: List[str] = []
    for binding in sig.arguments():
        if not is_tensor_like(binding.argument):
            # Non-tensor inputs are forwarded to the redispatch call as they are.
            redispatch_ctx.append(binding)
            continue

        # Tensor inputs get unwrapped into a local before the redispatch call.
        inner_name = f"{binding.name}_"
        # Most ops must sync pending updates on their inputs first, otherwise
        # they would operate on stale data. View ops can keep deferring the
        # sync until the tensor is consumed by a non-view operator.
        if is_view_op:
            sync_stmt = ""
        else:
            sync_stmt = f"at::functionalization::impl::sync({binding.name});"
        owning_type, to_owning = get_owning_type(
            binding.nctype.remove_const_ref().type
        )
        snippets.append(
            f"""
      {owning_type.cpp_type()} {inner_name};
      if (at::functionalization::impl::isFunctionalTensor({binding.name})) {{
        {sync_stmt}
        {inner_name} = at::functionalization::impl::from_functional_tensor({binding.name});
      }} else {{
        {inner_name} = {to_owning(binding.name)};
      }}"""
        )
        redispatch_ctx.append(binding.with_name(inner_name))
    return "\n      ".join(snippets), redispatch_ctx
예제 #7
0
 def wrapper_kernel_sig(
     self, f: NativeFunction
 ) -> Union[NativeSignature, DispatcherSignature]:
     """Return the dispatcher signature used for the wrapper kernel of ``f``.

     The ``wrapper_<overload_name>_`` prefix exists only to keep kernel names
     unique, since the Dispatcher API does not guarantee unique kernel names.
     """
     unique_prefix = f"wrapper_{f.func.name.overload_name}_"
     return DispatcherSignature.from_schema(f.func, prefix=unique_prefix)
예제 #8
0
def emit_trace_body(f: NativeFunction) -> List[str]:
    """Build the list of C++ statements that make up a tracing kernel body.

    The body consists of, in order: the pre-record trace code, the declaration
    of the returned variables, the redispatch call, the post-record trace code
    and, when the schema has returns, a final ``return`` statement.
    """
    lines: List[str] = [
        format_prerecord_trace(f),
        declare_returned_variables(f),
    ]

    sig = DispatcherSignature.from_schema(f.func)
    arg_exprs = sig.exprs()

    # Code-generated tracing kernels plumb and recompute dispatch keys directly
    # through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
    redispatch_args = ", ".join([key_set, *(e.expr for e in arg_exprs)])

    # Only functional ops that actually return something assign the result.
    if f.func.kind() == SchemaKind.functional and f.func.returns:
        assign_return_values = f"{tie_return_values(f)} = "
    else:
        assign_return_values = ""

    # Note that this calls the slow, dispatching variants of manual_cpp_binding
    # ops. Calling the fast variants instead would be possible with more work,
    # but the perf benefit would be minimal.
    dispatch_call = TRACE_DISPATCH.substitute(
        assign_return_values=assign_return_values,
        unambiguous_name=f.func.name.unambiguous_name(),
        unpacked_args=redispatch_args,
    )
    lines.append(dispatch_call)

    lines.append(format_postrecord_trace(f))
    if f.func.returns:
        lines.append(f"return {get_return_value(f)};")
    return lines
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = 'cur_level'

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all,
                                              cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(
        schema, cur_level_var)

    return f"""\
def gen_case_where_all_bdims_are_none(schema, cur_level_var) -> str:
    conditions = []
    flat_args = schema.arguments.flat_all
    for arg in flat_args:
        if not arg.type.is_tensor_like():
            continue
        conditions.append(f'!isBatchedAtLevel({arg.name}, {cur_level_var})')

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ', '.join(
        e.expr for e in translate(sig.arguments(), sig.arguments()))
    return f"""\
예제 #11
0
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Generate C++ that converts each tensor-like argument via ``to_meta``.

    Returns:
        A pair ``(code, context)``: the joined ``auto <arg>_meta = to_meta(<arg>);``
        statements, and the bindings for the subsequent call, where tensor-like
        bindings are renamed to ``<arg>_meta`` and the rest are unchanged.
    """
    call_ctx: List[Binding] = []
    conversions: List[str] = []
    for binding in sig.arguments():
        if not is_tensor_like(binding.argument):
            # Non-tensor inputs go straight into the redispatch call.
            call_ctx.append(binding)
            continue
        # Tensor inputs are converted before being passed to the redispatch call.
        meta_name = f"{binding.name}_meta"
        conversions.append(f"auto {meta_name} = to_meta({binding.name});")
        call_ctx.append(binding.with_name(meta_name))
    return "\n        ".join(conversions), call_ctx
예제 #12
0
def convert_to_meta_tensors(
        sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Generate ``to_meta`` conversions for the tensor-like arguments of ``sig``.

    A binding is converted only when its underlying argument is an
    ``Argument`` whose type reports ``is_tensor_like()``.

    Returns:
        A pair ``(code, context)``: the joined C++ conversion statements, and
        the bindings to call with (converted bindings renamed to
        ``<arg>_meta``, every other binding passed through unchanged).
    """
    call_ctx: List[Binding] = []
    statements: List[str] = []
    for binding in sig.arguments():
        underlying = binding.argument
        needs_conversion = (
            isinstance(underlying, Argument)
            and underlying.type.is_tensor_like()
        )
        if not needs_conversion:
            call_ctx.append(binding)
            continue
        meta_name = f"{binding.name}_meta"
        statements.append(f"auto {meta_name} = to_meta({binding.name});")
        call_ctx.append(binding.with_name(meta_name))
    return "\n        ".join(statements), call_ctx
예제 #13
0
 def emit_registration_helper(f: NativeFunction) -> str:
     """Return the ``m.impl(...)`` registration line for ``f``.

     Ops with a CompositeImplicitAutograd kernel are registered directly to
     their ``at::native::`` kernel; everything else is registered to the
     generated ``functionalization::`` wrapper.
     """
     if not f.has_composite_implicit_autograd_kernel:
         # Non-composite view ops (and inplace ops) get a normal registration.
         target = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
     else:
         kernel_info = composite_implicit_autograd_index.get_kernel(f)
         assert kernel_info is not None
         kernel_name = kernel_info.kernel
         sig = DispatcherSignature.from_schema(f.func)
         # Note [Composite view ops in the functionalization pass]
         # There is no need to implement functionalization kernels for views
         # with CompositeImplicitAutograd kernels, because they can simply be
         # decomposed into their base operators. The whole Functionalization
         # dispatch key cannot be opted into the composite keyset though,
         # because composite non-view ops such as `at::ones` must not be
         # decomposed.
         target = f"static_cast<{sig.ptr_type()}>(at::native::{kernel_name})"
     return f'm.impl("{f.func.name}", {target});'
예제 #14
0
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    """Build the C++ statements for an ADInplaceOrView kernel body.

    Ops that modify their arguments get a redispatch followed by one
    ``increment_version`` call per return name. All other ops must have view
    info; they get a redispatch into a temporary, the view-handling code from
    ``emit_view_body`` and an assignment of the resulting value. A trailing
    ``return`` statement is added when the schema has returns.
    """
    f = fn.func
    body: List[str] = []

    sig = DispatcherSignature.from_schema(f.func)
    arg_exprs = sig.exprs()

    # Code-generated ADInplaceOrView kernels plumb and recompute dispatch keys
    # directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    key_set = "ks & c10::after_ADInplaceOrView_keyset"
    redispatch_args = ", ".join([key_set, *(e.expr for e in arg_exprs)])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding
    # ops. Calling the fast variants instead would be possible with more work,
    # but the perf benefit would be minimal.
    if not modifies_arguments(f):
        # View op: redispatch into a temporary, then wrap it as a view.
        assert get_view_info(f) is not None
        view_redispatch = VIEW_REDISPATCH.substitute(
            assign_return_values="auto " + TMP_VAR + " = ",
            unambiguous_name=f.func.name.unambiguous_name(),
            unpacked_args=redispatch_args,
        )
        body.append(view_redispatch)
        view_call, rhs_value = emit_view_body(fn, TMP_VAR)
        body.append(view_call)
        assert rhs_value is not None
        assignment = ASSIGN_RETURN_VALUE.substitute(
            return_values=tie_return_values(f), rhs_value=rhs_value
        )
        body.append(assignment)
    else:
        # Inplace op: redispatch, then bump the version of every return.
        inplace_redispatch = INPLACE_REDISPATCH.substitute(
            unambiguous_name=f.func.name.unambiguous_name(),
            unpacked_args=redispatch_args,
        )
        body.append(inplace_redispatch)
        body.extend(f"increment_version({r});" for r in cpp.return_names(f))

    if f.func.returns:
        body.append(f"return {get_return_value(f)};")
    return body
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Generate C++ that builds a meta-device counterpart of each tensor input.

    Every tensor-like argument gets an ``auto <arg>_meta = ...`` statement that
    calls ``at::native::empty_strided_meta`` with the argument's sizes,
    strides, scalar type and layout, a ``kMeta`` device and no pin_memory.

    Returns:
        A pair ``(code, context)``: the joined C++ statements, and the bindings
        for the subsequent call, where tensor-like bindings are renamed to
        ``<arg>_meta`` and all other bindings are passed through unchanged.
    """
    call_ctx: List[Binding] = []
    meta_decls: List[str] = []
    for binding in sig.arguments():
        if not is_tensor_like(binding.argument):
            # Non-tensor inputs go straight into the redispatch call.
            call_ctx.append(binding)
            continue
        # Tensor inputs are replaced by a freshly created meta tensor.
        src = binding.name
        meta_name = f"{binding.name}_meta"
        meta_decls.append(
            f"auto {meta_name} = at::native::empty_strided_meta({src}.sizes(), {src}.strides(), "
            f"/*dtype=*/c10::make_optional({src}.scalar_type()), "
            f"/*layout=*/c10::make_optional({src}.layout()), "
            "/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt);"
        )
        call_ctx.append(binding.with_name(meta_name))
    return "\n        ".join(meta_decls), call_ctx
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
    # And we always write the kernel for the out= op in terms of the functional.
    # Note that the functional op might have also been generated, but we don't have to
    # worry about cycles, because the generated functional kernels are always implemented
    # in terms of non-generated kernels (see gen_composite_functional_kernel).

    sig = DispatcherSignature(g.out.func)
    target_sig = DispatcherSignature(g.functional.func)

    exprs = ", ".join(
        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
    )

    copy_outs = []
    out_name = "tmp_output"
    for i, out_arg in enumerate(g.out.func.arguments.out):
        functional_return_name = (
            out_name
            if len(g.functional.func.returns) == 1
            else f"std::get<{i}>({out_name})"
        )
        copy_outs.append(
            f"""\
  resize_out_helper({out_arg.name}, {functional_return_name});
  copy_arg({out_arg.name}, {functional_return_name});"""
        )

    rets = []
    # For each return arg in the calling (out=) operator,
    # If it corresponds to an aliased input, return the input.
    # Otherwise, return the corresponding output from calling the functional operator.
    for i, ret_name in enumerate(g.out.func.aliased_return_names()):
        if ret_name is not None:
            rets.append(ret_name)
        else:
            functional_return_name = (
                out_name
                if len(g.functional.func.returns) == 1
                else f"std::get<{i}>({out_name})"
            )
            rets.append(functional_return_name)

    copy_outs_str = "\n".join(copy_outs)

    # Kernel name needs to follow the naming convention defined in `generate_function()`
    return f"""
예제 #17
0
def gen_vmap_inplace_plumbing(
        native_function: NativeFunction) -> Optional[str]:
    # Assumptions:
    # - only one argument is being modified in-place
    # - the argument that is being modified in-place is the first argument
    # - all returns are either Tensor, tuple of Tensor, or TensorList
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    if not is_mutated_arg(schema.arguments.flat_all[0]):
        return None
    if not len(
        [arg
         for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(
            is_tensor(ret.type) or is_tensor_list(ret.type)
            for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all,
                                              cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(
        schema, cur_level_var)

    return f"""\
def gen_vmap_plumbing(native_function: NativeFunction) -> str:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None
    # in-place views need special handling
    if native_function.tag == Tag.inplace_view:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these
    if schema.kind() == SchemaKind.out:
        return None

    # From now on, assume we're dealing with a functional (out-of-place) operation
    assert schema.kind() == SchemaKind.functional

    results_var = 'results'
    cur_level_var = 'cur_level'

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all,
                                              cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(
        schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
예제 #19
0
    def shape_inference(self, func: NativeFunction,
                        schema: LazyIrSchema) -> str:
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join(
                    [this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(
                dispatcher_sig)
            meta_call_args = [
                e.expr for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False)
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(metadata.kernel, func)
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False)

    mutated_names = [
        a.name for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(["true"] + [
        f"at::functionalization::impl::isFunctionalTensor({a})"
        for a in mutated_names
    ])
    check_any_non_mutated_args_are_functional = " || ".join(["false"] + [
        f"at::functionalization::impl::isFunctionalTensor({a})"
        for a in non_mutated_names
    ])
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr for e in translate(
            unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."""

        return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
      at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
      {return_str(f)};
    }}
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [
            e.expr for e in translate(
                unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]

    if f.func.is_out_fn():
        mutable_input_post_processing = "\n".join([
            f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
            for (i, a) in enumerate(f.func.arguments.out) if a.annotation
            and a.annotation.is_write and a.type.is_tensor_like()
        ])
    else:
        mutable_input_post_processing = "\n".join([
            f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
            for a in f.func.arguments.flat_all if a.annotation
            and a.annotation.is_write and a.type.is_tensor_like()
        ])

    return f"""
예제 #21
0
def emit_inplace_functionalization_body(
    f: NativeFunction, g: NativeFunctionsGroup
) -> str:
    # Builds the pieces of C++ source text for the functionalization kernel of a
    # mutating operator `f`, using `g.functional` as its out-of-place counterpart.
    # Returns the assembled kernel as a string.
    # NOTE(review): the final template (the `return f"""` at the bottom) is cut off
    # in this excerpt, so how each piece below is spliced in is not visible here.

    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    # unwrap_tensor_args yields a (C++ snippet, binding context) pair; the context
    # is what the translate() calls below resolve argument expressions against.
    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    # Partition the tensor-like arguments by whether they carry an alias annotation.
    # Note that any annotation counts here; `is_write` is not checked at this point.
    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    # The leading "true" / "false" seeds keep the generated C++ boolean expression
    # well-formed even when the corresponding name list is empty.
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    # call the out-of-place variant of the op
    # return_type is the C++ type spelled from the functional op's returns, with
    # const-ref removed.
    return_type = (
        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
    )
    functional_sig = DispatcherSignature.from_schema(g.functional.func)
    # Argument expressions for calling the functional op, resolved from the same
    # unwrapped-argument context as inplace_exprs.
    functional_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    # Emit, per written tensor-like argument, a replace_() + commit_update() pair
    # that feeds the functional op's result (`tmp_output` in the generated C++)
    # back into that argument.
    if f.func.is_out_fn():
        # out= ops: when the op has more than one return, the i-th out argument is
        # paired with std::get<i>(tmp_output); otherwise tmp_output is used directly.
        # `i` is the position within f.func.arguments.out (enumerated before filtering).
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )
    else:
        # Non-out ops: every written tensor-like argument is replaced with the same
        # tmp_output value.
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )

    # Snippet + binding context for meta-tensor versions of the arguments.
    # NOTE(review): presumably consumed by the truncated template below - confirm.
    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)

    return f"""
예제 #22
0
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    # Builds the C++ source text of the functionalization kernel for a view op in
    # group `g`. With view_inplace=True the kernel is generated for g.view_inplace,
    # otherwise for g.view. Returns the kernel as a string.
    # NOTE(review): the template of the final `else:` branch is cut off in this
    # excerpt; only the "inplace_view" template is fully visible.
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionGroup's both, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    # A view_copy variant is required: call_sig and api_name below are derived from it.
    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        # NOTE(review): api_name is not referenced in the visible template; presumably
        # it is used by the truncated `else:` template - confirm.
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        # The first dispatcher argument is treated as the tensor being viewed.
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        # (C++ snippet, binding context) pair for the unwrapped arguments.
        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        # Argument expressions resolved against call_sig (the view_copy signature).
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        # Forward and reverse lambdas supply the decl()/inner_call() text that the
        # template below places inside the ViewMeta constructor.
        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        # Generated kernel for ops tagged "inplace_view": if the viewed tensor is not a
        # functional tensor, unwrap the arguments and redispatch to the original op;
        # otherwise build a ViewMeta from the two lambdas, run the original op on meta
        # tensors to obtain reference output, apply the view meta to the tensor, copy
        # sizes/strides/offset from the reference output, and return the tensor.
        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays them off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up runnin the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
    }}
"""

        else:
            return f"""
예제 #23
0
 def create_decl(f: NativeFunction) -> str:
     """Return the declaration string of the dispatcher signature built from ``f.func``.

     Both the signature construction and the ``decl()`` call run inside the
     ``native_function_manager(f)`` context.
     """
     with native_function_manager(f):
         dispatcher_sig = DispatcherSignature.from_schema(f.func)
         return dispatcher_sig.decl()