Ejemplo n.º 1
0
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *, method: bool = False,
    allow_expensive_conversions: bool = False
) -> List[Expr]:
    """Synthesize, for every goal, a C++ expression built from ``bindings``.

    Args:
        bindings: What is in scope.  A ``Binding`` is treated as an ``Expr``
            whose expression is the binding's name and whose type is its
            ``nctype``.
        goals: The named C++ types to produce, in order.  A ``Binding`` goal
            stands for its ``nctype``.
        method: When true, implicit ``self`` tensor bindings (mutable and
            const reference) are added to the context, written in terms of
            ``*this``.
        allow_expensive_conversions: Opt in to the reference-to-value
            conversions described in
            Note [translation from C++ reference to value types] below.

    Returns:
        One ``Expr`` per goal, in the same order as ``goals``.

    Raises:
        UnsatError: If some goal cannot be synthesized from the context.
    """

    # Normalize both inputs so the rest of the function only deals with
    # Expr (for bindings) and NamedCType (for goals).
    binding_exprs: List[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]
    goal_ctypes: List[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g
        for g in goals
    ]

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type
        # NOTE(review): `t` is used as a NamedCType everywhere below (t.type,
        # t.name), so this isinstance test on `t` itself looks like it can
        # never succeed -- confirm whether `t.type` was intended.  Left as-is
        # because enabling it would add new entries to the context.
        if isinstance(t, ConstRefCType) and isinstance(t.elem, OptionalCType) and \
                isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == 'at::Tensor':
            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'({b.expr}).to<opmath_t>()'

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \
                f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())'

        if t.type == BaseCType(scalar_t):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'static_cast<opmath_t>({b.expr})'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        """Raise UnsatError for ``goal``, listing every binding in the context."""
        ctx_desc = '\n'.join(f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A crude backtracking search implementation.  It's crude because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        """Return a C++ expression for ``goal`` or raise UnsatError.

        With ``direct=True`` only a context lookup and the trivial reference
        conversions are attempted; the synthesis rules are skipped.
        """
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is expensive
        # (O(n) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # This branch previously repeated the scalarT test above and was
            # therefore unreachable; the goal it serves is the optional Tensor
            # value type, the counterpart of OptionalTensorRef.
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # With arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
Ejemplo n.º 2
0
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    """Rewrite ``formula`` and collect the attributes it needs saved.

    For each argument in ``nctypes``, every pattern in ``REPLACEMENTS`` is
    applied to the formula.  Each match records a ``SavedAttribute`` and is
    replaced in the formula by either the entry's ``'res'`` result or
    ``<name><suffix>``.  Afterwards, if ``IDENT_REGEX`` formatted with the
    argument name still matches the formula, the argument itself is saved.

    Returns:
        The rewritten formula and the saved attributes, in the order they
        were recorded.
    """

    def stride_expr(name: str) -> str:
        # Save-time expression for "<name>.strides()"; only allowed when the
        # derivative is for exactly that one variable.
        assert var_names == (name, ), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.')
        return f'strides_or_error({name}, "{name}")'

    # Each entry is (regex template, spec).  The template is formatted with
    # the argument name.  Spec keys:
    #   'suffix': str, or callable(match) -> str, appended to the name
    #   'nctype': callable(saved_name) -> NamedCType of the saved attribute
    #   'expr'  : optional callable(name) -> expression evaluated at save-time
    #             (defaults to the matched text)
    #   'res'   : optional callable(name) -> text substituted into the formula
    #             (defaults to name + suffix)
    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # self.sizes() -> self_sizes
        (
            r'{}.sizes\(\)',
            {
                'suffix': '_sizes',
                'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            },
        ),
        # self.options() -> self_options
        (
            r'{}.options\(\)',
            {
                'suffix': '_options',
                'nctype': lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # zeros_like(self) -> self_info.zeros(), saving self as self_info
        (
            r'zeros_like\({}\)',
            {
                'suffix': '_info',
                'nctype': lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                'expr': lambda name: name,  # at save-time
                'res': lambda name: name + '_info.zeros()',  # at eval-time
            },
        ),
        # self.size(2) -> self_argsize_2
        (
            r'{}.size\((\w+)\)',
            {
                'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
                'nctype': lambda name: NamedCType(name, BaseCType(intT)),
            },
        ),
        # self.numel() -> self_numel
        (
            r'{}.numel\(\)',
            {
                'suffix': '_numel',
                'nctype': lambda name: NamedCType(name, BaseCType(intT)),
            },
        ),
        # to_args_sizes(self) -> self_args_sizes
        (
            r'to_args_sizes\({}\)',
            {
                'suffix': '_args_sizes',
                'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(intT)))),
            },
        ),
        # to_args_scalartypes(self) -> self_args_scalartypes
        (
            r'to_args_scalartypes\({}\)',
            {
                'suffix': '_args_scalartypes',
                'nctype': lambda name: NamedCType(name, VectorCType(BaseCType(scalarTypeT))),
            },
        ),
        # TensorGeometry(self) -> self_geometry
        (
            r'TensorGeometry\({}\)',
            {
                'suffix': '_geometry',
                'nctype': lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # self.scalar_type() -> self_scalar_type
        (
            r'{}.scalar_type\(\)',
            {
                'suffix': '_scalar_type',
                'nctype': lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # self.dim() -> self_dim
        (
            r'{}.dim\(\)',
            {
                'suffix': '_dim',
                'nctype': lambda name: NamedCType(name, BaseCType(intT)),
            },
        ),
        # self.strides() -> self_strides
        (
            r'{}.strides\(\)',
            {
                'suffix': '_strides',
                'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
                'expr': stride_expr,
            },
        ),
    ]

    # Attributes recorded so far, in discovery order.
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        # A SpecialArgName carries its plain string name in `.name`.
        if isinstance(nctype.name, SpecialArgName):
            arg_name = nctype.name.name
        else:
            arg_name = nctype.name

        # First look for sub-expressions that can be evaluated when the
        # autograd Function is created, so the variable itself need not be
        # saved.
        for pattern, spec in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                # `spec` and `arg_name` are read from the enclosing loop; this
                # is safe because re.sub below invokes repl before either is
                # rebound.
                suffix_spec = spec['suffix']
                suffix: str = suffix_spec(m) if callable(suffix_spec) else suffix_spec
                if 'expr' in spec:
                    save_expr: str = spec['expr'](arg_name)
                else:
                    save_expr = m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=spec['nctype'](arg_name + suffix),
                        expr=save_expr,
                    ))
                if 'res' not in spec:
                    return arg_name + suffix
                replacement: str = spec['res'](arg_name)
                return replacement

            formula = re.sub(pattern.format(arg_name), repl, formula)

        # Whatever still refers to the variable after the rewrites needs the
        # variable itself saved.
        if re.search(IDENT_REGEX.format(arg_name), formula):
            saved.append(SavedAttribute(
                nctype=nctype,
                expr=arg_name,
            ))

    return formula, tuple(saved)
Ejemplo n.º 3
0
# other scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "options" binding?  Well, maybe "options" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

# The named type for an argument called "options": a const reference to the
# tensor-options base type.
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

# Unnamed value types (vector of long, optional vector of long, optional
# scalar, optional tensor).  They carry no argument name; a NamedCType is
# built from them wherever a name is known.
longVec_ctype = VectorCType(BaseCType(longT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))

class UnsatError(RuntimeError):
    """Raised when a requested expression cannot be synthesized."""

# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
Ejemplo n.º 4
0
    def solve(goal: NamedCType, *, direct: bool) -> str:
        """Return a C++ expression for ``goal`` or raise UnsatError.

        Reads ``ctx``, ``goal_ctypes``, ``allow_expensive_conversions`` and
        ``unsat`` from the enclosing scope.  With ``direct=True`` only a
        context lookup and the trivial reference conversions are attempted;
        the synthesis rules further down are skipped.
        """
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is expensive
        # (O(n) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # This branch previously repeated the scalarT test above and was
            # therefore unreachable; the goal it serves is the optional Tensor
            # value type, the counterpart of OptionalTensorRef.
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # With arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        unsat(goal)
Ejemplo n.º 5
0
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site, and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
# Binding named 'base': a const-reference tensor, backed by a Tensor-typed
# Argument with no default and no annotation.
base_binding = Binding(
    name='base',
    nctype=NamedCType(
        name='base',
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name='base',
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Binding named 'mutated_view': same shape as base_binding.
# NOTE(review): the wrapped Argument is named 'base', not 'mutated_view' --
# confirm whether that is intentional before relying on argument.name.
mutated_view_binding = Binding(
    name='mutated_view',
    nctype=NamedCType(
        name='mutated_view',
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name='base',
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
mutated_view_idx_binding = Binding(name='mutated_view_idx',
Ejemplo n.º 6
0
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

# Lookup table keyed by C++ type.  Each value is a pair of
# (getter definition, getter body), built from the GETTER_DEFINITION* and
# GETTER_BODY_* names (presumably code templates defined elsewhere in this
# file -- confirm).  Optional types pair GETTER_DEFINITION_OPT with the same
# body used for their element type.
# NOTE(review): for intT only the Optional variant is listed here; a plain
# intT key is absent -- presumably handled elsewhere, confirm.
MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(intT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
# Note: this binds a second name to the same object as VIEW_FUNCTIONS (it is
# an alias, not a copy), so an in-place change to either is visible via both.
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
Ejemplo n.º 7
0
    def check_tensorimpl_and_storage(call: str,
                                     unpacked_bindings: List[Binding]) -> str:
        """Wrap ``call`` with debug-only TensorImpl/Storage sanity checks.

        Statements are collected to run before and after ``call``: save/enforce
        pairs for each tensor-like input binding, and use-count or
        storage-aliasing checks for tensor outputs.  The function being
        generated is ``f`` from the enclosing scope.

        Args:
            call: The statement(s) performing the actual call.
            unpacked_bindings: Input bindings whose types decide which checks
                are emitted.

        Returns:
            ``call`` unchanged when the function is listed in
            DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE or when no before/after
            pair was collected; otherwise ``call`` surrounded by the two
            RUN_ONLY_IN_DEBUG_MODE blocks.

        Raises:
            AssertionError: If a return type is a tensor list or a list of
                optional tensors (no use_count checks exist for those yet), or
                if a view function has a tensor return at a position other
                than the first.
        """
        # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
        stmts_before_call: List[str] = []
        stmts_after_call: List[str] = []

        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            return call

        # Check properties of inputs (enforce (1))
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                stmts_before_call += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(
                    BaseCType(tensorT))):
                stmts_before_call += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                stmts_before_call += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(
                        tensor_name=arg, out_tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]

        # Input checks always come in save/enforce pairs, so at this point the
        # two lists are either both empty or both non-empty.
        assert bool(stmts_before_call) == bool(stmts_after_call)

        # Check properties of outputs (enforce (2), (3))
        if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
            if aliased_arg_name is not None:
                aliased_arg_name = unpacked_name(aliased_arg_name)
            for i, (ret, ret_name) in enumerate(
                    zip(f.func.returns, cpp.return_names(f))):
                noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                if noref_cpp_type == BaseCType(tensorT):
                    if aliased_arg_name is not None:
                        # The message used to be a plain string containing a
                        # literal "{base}"; interpolate the function name.
                        assert i == 0, (
                            f"Expect non-CompositeImplicitAutograd view function {base_name} "
                            "to return single output")
                        stmts_after_call += [
                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
                                tensor_name=aliased_arg_name,
                                out_tensor_name=ret_name)
                        ]
                    else:
                        if type_wrapper_name(
                                f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.
                                substitute(tensor_name=ret_name,
                                           fn_name=type_wrapper_name(f))
                            ]

                    if type_wrapper_name(
                            f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.
                            substitute(tensor_name=ret_name,
                                       fn_name=type_wrapper_name(f))
                        ]

                # Currently we don't have any functions that return the following types, but
                # we should update the checks once we do
                elif noref_cpp_type == ListCType(
                        OptionalCType(BaseCType(tensorT))):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")
                elif noref_cpp_type == BaseCType(tensorListT):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")

        # NOTE(review): output checks only extend stmts_after_call, so when no
        # input produced a save/enforce pair (stmts_before_call is empty) the
        # output checks collected above are not emitted -- confirm intended.
        if stmts_before_call and stmts_after_call:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
        return call
Ejemplo n.º 8
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Pick the C++ type used to bind an argument of schema type ``t``.

    Value types are delegated to ``valuetype_type``.  The remaining cases
    (Tensor, Scalar, and optionals/lists built from them) are handled here,
    recursing on the element type for generic optionals and lists.

    Args:
        t: Schema type of the argument.
        mutable: Whether the argument is mutable; for tensors this selects a
            mutable reference instead of a const reference.
        binds: Name that the returned ``NamedCType`` is bound to.

    Returns:
        A ``NamedCType`` pairing ``binds`` with the selected C++ type.

    Raises:
        AssertionError: If ``t`` is a base type that is neither a value type,
            Tensor nor Scalar, or if ``t`` is not a base/optional/list type.
    """
    # Value types win over every rule below.
    as_value = valuetype_type(t, binds=binds)
    if as_value is not None:
        return as_value

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            tensor = BaseCType(tensorT)
            return NamedCType(
                binds,
                MutRefCType(tensor) if mutable else ConstRefCType(tensor))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        elem_name = str(t.elem)
        if elem_name == 'Tensor':
            if mutable:
                # TODO: fix this discrepancy (the mutable case drops the
                # optional wrapper and binds a plain mutable reference).
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        if elem_name == 'Scalar':
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        # Generic optional: wrap whatever the element type maps to.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        elem_name = str(t.elem)
        if elem_name == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if elem_name == 'Tensor':
            return NamedCType(binds, BaseCType(tensorListT))
        if elem_name == 'Scalar':
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if elem_name == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        if elem_name == 'Tensor?':
            return NamedCType(
                binds,
                ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
        # Generic list: an ArrayRef of whatever the element type maps to.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
Ejemplo n.º 9
0
    def solve(goal: NamedCType, *, direct: bool) -> str:
        """Return an expression string that satisfies ``goal``.

        Looks the goal up in the enclosing ``ctx`` first, then tries the
        reference-weakening conversions, and finally (unless ``direct`` is
        set) the synthesis rules for memory_format / options and its fields.

        Args:
            goal: The named C++ type to produce an expression for.
            direct: When true, only the context lookup and reference
                conversions are attempted; no synthesis rules are applied.

        Raises:
            UnsatError: Via ``unsat`` when no rule produces the goal.
        """
        def direct_solve(subgoal: NamedCType) -> str:
            return solve(subgoal, direct=True)

        # Trivial: the goal is already bound in the context.
        if goal in ctx:
            return ctx[goal]

        goal_type = goal.type

        # const & is satisfied with mutable &
        if isinstance(goal_type, ConstRefCType):
            # WARNING: not strictly decreasing; be careful not to add a
            # direct conversion that satisfies mutable& with const&
            as_mutable_ref = NamedCType(goal.name, MutRefCType(goal_type.elem))
            try:
                return solve(as_mutable_ref, direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal_type, MutRefCType):
            as_value = NamedCType(goal.name, goal_type.elem)
            try:
                return solve(as_value, direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive, and every one
        # of them either returns or raises.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
            except UnsatError:
                # No options in scope: memory_format stands on its own.
                return memory_format
            return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"

        if goal == NamedCType("options", BaseCType(tensorOptionsT)):
            # The solving order below decides which missing piece is
            # reported first, so keep it stable.
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        # The remaining goals are individual fields extracted from "options".
        if goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            return f'optTypeMetaToScalarType({direct_solve(options_ctype)}.dtype_opt())'

        if goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            return f'{direct_solve(options_ctype)}.layout_opt()'

        if goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            return f'{direct_solve(options_ctype)}.device_opt()'

        if goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            return f'{direct_solve(options_ctype)}.pinned_memory_opt()'

        unsat(goal)
Ejemplo n.º 10
0
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *, method: bool = False
) -> List[Expr]:
    """Synthesize one expression per goal from the in-scope bindings.

    Args:
        bindings: What is in scope.  A ``Binding`` is turned into an ``Expr``
            whose expression is simply the binding's name.
        goals: The named C++ types to produce.  A ``Binding`` contributes its
            ``nctype``.
        method: When true, "self" bindings referring to ``*this`` are added
            to the context.

    Returns:
        One ``Expr`` per goal, in the order the goals were given.

    Raises:
        UnsatError: If some goal cannot be synthesized from the context.
    """
    # Normalize the inputs: bindings -> Expr, goals -> NamedCType.
    binding_exprs: List[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]
    goal_ctypes: List[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g
        for g in goals
    ]

    # Add all the bindings to the context.  Insertion order is what the
    # error message produced by unsat() lists.
    ctx: Dict[NamedCType, str] = {}
    for bound in binding_exprs:
        ctx[bound.type] = bound.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        # TODO: This could get us in recomputation trouble if the bound
        # expression is nontrivial
        # NOTE(review): bound.type is used as a ctx key (a NamedCType), so the
        # isinstance checks on t itself may never match - confirm whether
        # the wrapped C++ type was meant to be inspected here.
        t = bound.type
        is_const_ref_optional_tensor = (
            isinstance(t, ConstRefCType)
            and isinstance(t.elem, OptionalCType)
            and isinstance(t.elem.elem, BaseCType)
            and str(t.elem.elem.type) == 'at::Tensor'
        )
        if is_const_ref_optional_tensor:
            unwrapped = NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ctx[unwrapped] = f'({bound.expr}.has_value() ? *{bound.expr} : at::Tensor())'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        self_expr = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = self_expr
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = self_expr
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        # Always raises; the message lists every binding in the context.
        ctx_desc = '\n'.join(f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A very simple stand-in for a backtracking search: it does not actually
    # backtrack or search.  In particular, if direct=True, we won't try to
    # do any fancy synthesis, just trivial conversions (e.g., "T a" is OK
    # for "const T& a").  So all of the existing rules in this function
    # simply try to solve immediately, and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(subgoal: NamedCType) -> str:
            return solve(subgoal, direct=True)

        # Trivial: already bound.
        if goal in ctx:
            return ctx[goal]

        goal_type = goal.type

        # const & is satisfied with mutable &
        if isinstance(goal_type, ConstRefCType):
            # WARNING: not strictly decreasing; be careful not to add a
            # direct conversion that satisfies mutable& with const&
            as_mutable_ref = NamedCType(goal.name, MutRefCType(goal_type.elem))
            try:
                return solve(as_mutable_ref, direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal_type, MutRefCType):
            as_value = NamedCType(goal.name, goal_type.elem)
            try:
                return solve(as_value, direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive, and every one
        # of them either returns or raises.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
            except UnsatError:
                # No options in scope: memory_format stands on its own.
                return memory_format
            return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"

        if goal == NamedCType("options", BaseCType(tensorOptionsT)):
            # The solving order below decides which missing piece is
            # reported first, so keep it stable.
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        # The remaining goals are individual fields extracted from "options".
        if goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            return f'optTypeMetaToScalarType({direct_solve(options_ctype)}.dtype_opt())'

        if goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            return f'{direct_solve(options_ctype)}.layout_opt()'

        if goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            return f'{direct_solve(options_ctype)}.device_opt()'

        if goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            return f'{direct_solve(options_ctype)}.pinned_memory_opt()'

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
Ejemplo n.º 11
0
# other scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "options" binding?  Well, maybe "options" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

# The named C++ type for an "options" argument: a const reference to the
# tensor-options type, bound to the name "options".
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

class UnsatError(RuntimeError):
    """Signals that an expression for a requested goal could not be synthesized."""
    pass

# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
Ejemplo n.º 12
0
# other scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "options" binding?  Well, maybe "options" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

# The C++ type for an "options" argument: a const reference to a
# "TensorOptions" base type carrying the binding name "options".
options_ctype = ConstRefCType(BaseCType("TensorOptions", "options"))

class UnsatError(RuntimeError):
    """Signals that an expression for a requested goal could not be synthesized."""
    pass

# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
Ejemplo n.º 13
0
def process_ir_type(
        typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Convert a NativeFunctions type into the C++ type used by lazy tensor codegen.

    The conversion currently does three things:
     (1) Tensor (and Scalar) become lazy Values.  Optional[Tensor] and
         List[Tensor]-like types get their own handling below.
     (2) everything is wrapped in a CType
     (3) list types become value types (a vector) rather than reference
         types such as IntArrayRef

    Support is intentionally incomplete: unhandled types raise an
    AssertionError so that they get added as more operators are generated.
    """
    if isinstance(typ, BaseType):
        # Tensor and Scalar both map to a lazy Value; at::Scalar has special
        # handling and is wrapped in a Value just like at::Tensor.
        base_cpp_types = {
            BaseTy.Tensor: valueT,
            BaseTy.Scalar: valueT,
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        cpp_type = base_cpp_types.get(typ.name)
        if cpp_type is None:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
        return BaseCType(cpp_type)

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))

    if isinstance(typ, ListType):
        elem_name = str(typ.elem)
        if elem_name == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        if elem_name == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        return VectorCType(process_ir_type(typ.elem))

    raise AssertionError(f"unrecognized type {repr(typ)}")
Ejemplo n.º 14
0
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

# Maps a saved-attribute C++ type to the pair (getter definition template,
# getter body snippet) used to generate its Python getter.
MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
Ejemplo n.º 15
0
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        """Generate the code for one native function of a structured group.

        What is produced depends on ``self.target``:

        - ``NAMESPACED_DECLARATION``: ``TORCH_API`` declarations for the C++
          signature (and the faithful signature, if there is one).
        - ``NAMESPACED_DEFINITION``: definitions of those signatures that
          forward to the ``wrapper_``-prefixed function.
        - ``ANONYMOUS_DEFINITION``: the structured operator class (from
          ``self.gen_class``) followed by the wrapper function definition.
        - ``REGISTRATION``: the ``m.impl(...)`` registration line.

        Returns:
            The generated source text, or ``None`` when the function is not
            selected for registration, or when the dispatch key is
            CompositeExplicitAutograd and ``f`` is an out variant.
        """
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:
            # Each public signature simply forwards to the wrapper, with its
            # arguments translated to the wrapper's argument list.
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                # Any other backend derives from the structured kernel class
                # recorded in the backend index.
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

            # For CUDA dispatch keys, emit a device check over the out
            # arguments followed by the flat positional arguments.
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out,
                    f.func.arguments.flat_positional
                )
                sig_body.append(RegisterDispatchKey.gen_device_check(f.device_check, list(device_check_args), sig.name()))

            # Instantiate the operator object; inplace and out variants pass
            # the tensor(s) to be written to its constructor.
            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ', '.join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(
                e.expr for e in translate(
                    context,
                    structured.meta_arguments(self.g),
                    method=False
                )
            )
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
                context.append(Expr(
                    expr=f"op.outputs_[{i}]",
                    # TODO: Stop hardcoding that the output type is a Tensor.  Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=NamedCType(out_arg.nctype.name, MutRefCType(BaseCType(tensorT)))
                ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(
                    e.expr for e in translate(
                        context,
                        out_sig.arguments(),
                        method=False
                    )
                )
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(
                    e.expr for e in translate(
                        context,
                        structured.impl_arguments(self.g),
                        method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)))
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ', '.join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
Ejemplo n.º 16
0
    def save_var(var: SavedAttribute, is_output: bool) -> None:
        """Emit the generated-code fragments for one saved attribute.

        Depending on the attribute's C++ type, this appends the member
        declaration to ``saved_variables``, and where applicable the release
        statements, unpack statements, asserts and Python getter definitions
        to the corresponding lists of the enclosing scope.  Finally it
        registers the getset struct entries for the attribute.

        Args:
            var: The saved attribute (its name and C++ type are read from
                ``var.nctype``).
            is_output: Whether the attribute is an output; this affects how a
                SavedVariable is unpacked and whether a Scalar is stored as
                a SavedVariable.
        """
        name = var.nctype.name
        ctype = var.nctype.type
        emit_getsetdef = True
        emit_raw_getsetdef = False

        stored_as_saved_variable = (
            ctype == BaseCType(tensorT)
            or ctype == OptionalCType(BaseCType(tensorT))
            or ctype == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (ctype == BaseCType(scalarT) and is_output)
        )

        if stored_as_saved_variable:
            saved_variables.append(f'SavedVariable {name}_;')
            release_variables.append(f'{name}_.reset_data();')
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append(f'auto {name} = {name}_.unpack({ptr});')
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR))
            emit_raw_getsetdef = True
        elif ctype == BaseCType(tensorListT) or \
                ctype == ListCType(OptionalCType(BaseCType(tensorT))):
            # Both list flavours are stored the same way; only the helper
            # used to unpack them differs.
            if ctype == BaseCType(tensorListT):
                unpack_fn = 'unpack_list'
            else:
                unpack_fn = 'unpack_opt_list'
            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
            saved_variables.append(f'bool {name}_released_ = false;')
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f'{name}_.clear();')
            release_variables.append(f'{name}_released_ = true;')
            unpack.append(f'auto {name} = {unpack_fn}({name}_);')
            asserts.append(
                f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
            emit_raw_getsetdef = True
        elif ctype == BaseCType(intArrayRefT):
            saved_variables.append(f'std::vector<int64_t> {name};')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
        elif ctype == BaseCType(optionalIntArrayRefT) or \
                ctype == OptionalCType(BaseCType(intArrayRefT)):
            # Two spellings of an optional int array share one representation.
            saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
        elif ctype == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f'c10::OptionalArray<double> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
        elif ctype == BaseCType(longT):
            # The only member that is declared with an explicit initializer.
            saved_variables.append(f'{ctype.cpp_type()} {name} = 0;')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T))
        elif ctype == BaseCType(stringT):
            saved_variables.append(f'std::string {name};')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING))
        elif ctype == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f'c10::optional<std::string> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING))
        else:
            saved_variables.append(f'{ctype.cpp_type()} {name};')

            getter_entry = MISC_GETTER_DEFS.get(ctype)
            if getter_entry is not None:
                getter_def, body = getter_entry
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body))
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                emit_getsetdef = False

        if emit_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
        if emit_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
Ejemplo n.º 17
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate a schema type into the CType used for a C++ argument.

    Value types are delegated to ``valuetype_type``; whatever it returns
    (when not ``None``) is used as-is.  Otherwise the translation is chosen
    from the shape of ``t``:

    * ``BaseType``: ``Tensor`` becomes a mutable or const reference
      depending on ``mutable``; ``Scalar`` is always a const reference.
      Any other base type raises, since it should have been handled as a
      value type.
    * ``OptionalType``: ``Tensor?`` and ``Scalar?`` get dedicated
      reference forms; every other element type is translated recursively
      and wrapped in ``OptionalCType``.
    * ``ListType``: a handful of element types map to named C++ list
      aliases; everything else becomes ``ArrayRef<elem>``.

    Args:
        t: The schema type to translate.
        mutable: Whether the argument is mutable; only affects ``Tensor``
            and ``Tensor?`` (and is forwarded to recursive calls).
        binds: The argument name attached to the resulting ``BaseCType``.

    Returns:
        The CType describing the C++ argument type.

    Raises:
        AssertionError: If ``t`` is a base type that is neither a value
            type, ``Tensor`` nor ``Scalar``, or if ``t`` is not a
            recognized kind of type at all.
    """
    # Value types take precedence over everything below.
    value_ctype = valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            tensor = BaseCType('Tensor', binds)
            return MutRefCType(tensor) if mutable else ConstRefCType(tensor)
        if t.name == BaseTy.Scalar:
            return ConstRefCType(BaseCType('Scalar', binds))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        optional_elem = str(t.elem)
        if optional_elem == 'Tensor':
            if mutable:
                # TODO: fix this discrepancy -- the mutable case drops the
                # optional wrapper and yields a plain mutable Tensor ref.
                return MutRefCType(BaseCType('Tensor', binds))
            return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
        if optional_elem == 'Scalar':
            return ConstRefCType(OptionalCType(BaseCType('Scalar', binds)))
        return OptionalCType(
            argumenttype_type(t.elem, mutable=mutable, binds=binds))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, need to add back
        list_aliases = {
            'int': "IntArrayRef",
            'Tensor': "TensorList",
            'Scalar': "ArrayRef<Scalar>",
            'Dimname': "DimnameList",
        }
        list_elem = str(t.elem)
        if list_elem in list_aliases:
            return BaseCType(list_aliases[list_elem], binds)
        if list_elem == 'Tensor?':
            return ConstRefCType(
                BaseCType("c10::List<c10::optional<Tensor>>", binds))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{inner.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")