Example #1
def process_ir_type(
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
     (1) changing at::Tensors into lazy::Values
     (2) wrapping everything in a BaseCType
     (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
    There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'

    This is incomplete- there are assertions in places that it's expected to need to add
    more types as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::Scalar has special handling:
            # it gets wrapped in a lazy::Value just like at::Tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? Or should it use a Vector, like above?
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
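
# Illustrative mappings (a sketch, not from the source; it assumes some
# LazyIrProperties instance `props` with TreatScalarsAsConstants unset, and
# the torchgen.model Type constructors used by the function above):
#
#   process_ir_type(BaseType(BaseTy.Tensor), props)
#       -> BaseCType(getValueT())
#   process_ir_type(OptionalType(BaseType(BaseTy.Scalar)), props)
#       -> OptionalCType(BaseCType(getValueT()))
#   process_ir_type(ListType(BaseType(BaseTy.int), None), props)
#       -> VectorCType(BaseCType(longT))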
Example #2
def returntype_type(t: Type, *, mutable: bool) -> CType:
    # the binds placeholder is ignored for return types
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")
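
# A hedged sketch of what returntype_type produces (illustrative, not from
# the source; it assumes the same torchgen.model constructors as above):
#
#   returntype_type(BaseType(BaseTy.Tensor), mutable=False)
#       -> BaseCType(tensorT)               # a copy, see Note [Tensor Copy Returns]
#   returntype_type(BaseType(BaseTy.Tensor), mutable=True)
#       -> MutRefCType(BaseCType(tensorT))  # or ConstRefCType under the local flag
#   returntype_type(ListType(BaseType(BaseTy.Tensor), None), mutable=False)
#       -> VectorCType(BaseCType(tensorT))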
Example #3
def argumenttype_type(t: Type,
                      *,
                      mutable: bool,
                      binds: ArgName,
                      remove_non_owning_ref_types: bool = False) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(t,
                       binds=binds,
                       remove_non_owning_ref_types=remove_non_owning_ref_types)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(
                    BaseCType(tensorT)))  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds,
                              ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "SymInt":
            return NamedCType(binds, BaseCType(symIntArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds,
                ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
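
# Illustrative results (a sketch under the same assumptions as above; `binds`
# is simply the name the resulting C++ expression will bind to):
#
#   argumenttype_type(BaseType(BaseTy.Tensor), mutable=False, binds="self")
#       -> NamedCType("self", ConstRefCType(BaseCType(tensorT)))
#   argumenttype_type(ListType(BaseType(BaseTy.int), None), mutable=False, binds="size")
#       -> NamedCType("size", BaseCType(intArrayRefT))
#   argumenttype_type(OptionalType(BaseType(BaseTy.Scalar)), mutable=False, binds="alpha")
#       -> NamedCType("alpha", ConstRefCType(OptionalCType(BaseCType(scalarT))))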
Example #4
def isValueType(typ: CType,
                properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type.  This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping
        # at::Scalar in a lazy::Value, while preserving other 'scalar' types
        # as plain scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (typ.type == getValueT()
                or (typ.type == scalarT and not treat_scalars_as_constants)
                or typ.type == SymIntT)
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: this should report True; it returns False for now
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
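
# A quick sketch of the predicate (illustrative, not from the source):
#
#   isValueType(BaseCType(getValueT()))                 -> True
#   isValueType(OptionalCType(BaseCType(getValueT())))  -> True   (recurses into elem)
#   isValueType(BaseCType(scalarT))                     -> True, unless
#                        properties.TreatScalarsAsConstants is set
#   isValueType(VectorCType(BaseCType(SymIntT)))        -> False  (see TODO above)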
Example #5
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name,
                                        MutRefCType(goal.type.elem)),
                             direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem),
                             direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format",
                              OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                ))
            # No need to join "memory_format" and "options" if the target API takes "options" directly;
            # joining them anyway would trigger the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype",
                                OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # we don't actually care about the value of pin_memory;
                # solve for the out tensor anyway to confirm it is in scope.
                direct_solve(out_tensor_ctype)
                return "c10::nullopt"

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT)))
                return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(
                NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.expectInt()"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type
        # is expensive (e.g. O(n) to convert an IntArrayRef to a vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name,
                                               BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name,
                                                  BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # with arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
        # We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT)))
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)
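
    # A worked trace (a sketch, not from the source): with
    # ctx == {options_ctype: "options"} and the goal
    #   NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
    # the "dtype" rule above fires, direct_solve(options_ctype) yields
    # "options", and solve() returns the C++ expression
    #   "optTypeMetaToScalarType(options.dtype_opt())"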
Example #6
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)
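
# A compact illustration (not from the source) of the three moves above, in
# terms of the C++ expressions that the solver synthesizes:
#   Gather:  "dtype" from "options"       -> "optTypeMetaToScalarType(options.dtype_opt())"
#   Scatter: "options" from its fields    -> "TensorOptions().dtype(dtype).layout(layout)..."
#   Join:    "memory_format" + "options"  -> "c10::impl::check_tensor_options_and_extract_memory_format(...)"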

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions, drawing only on the in-scope bindings (`bindings`),
# that covers all of the types of the goals.  You may want to use this function if
# you're generating code for a function like:
#
Example #7
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name, ), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.')
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (
            r"{}.sizes\(\)",
            {
                "suffix": "_sizes",
                "nctype":
                lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            },
        ),
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix":
                "_sym_sizes",
                "nctype":
                lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sizes() with self_sizes_opt
        (
            r"{}->sizes\(\)",
            {
                "suffix":
                "_sizes_opt",
                "nctype":
                lambda name: NamedCType(
                    name, OptionalCType(BaseCType(intArrayRefT))),
                "expr":
                lambda name:
                f"{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype":
                lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype":
                lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.size(2) with self_size_2
        (
            r"{}.size\((\w+)\)",
            {
                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix":
                "_args_sizes",
                "nctype":
                lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix":
                "_args_scalartypes",
                "nctype":
                lambda name: NamedCType(name,
                                        VectorCType(BaseCType(scalarTypeT))),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix":
                "_geometry",
                "nctype":
                lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype":
                lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.strides() with self_strides
        (
            r"{}.strides\(\)",
            {
                "suffix": "_strides",
                "nctype":
                lambda name: NamedCType(name, BaseCType(intArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = (nctype.name.name
                if isinstance(nctype.name, SpecialArgName) else nctype.name)
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = (info["suffix"](m)
                               if callable(info["suffix"]) else info["suffix"])
                expr: str = info["expr"](name) if "expr" in info else m.group(
                    0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    ))
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(SavedAttribute(
                nctype=nctype,
                expr=name,
            ))

    return formula, tuple(saved)
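
# A hedged sketch of saved_variables in action (illustrative inputs, not from
# the source): given formula == "grad * self.sizes()[0]", an entry in nctypes
# named "self", and var_names == ("self",), the ".sizes()" rule rewrites the
# formula to "grad * self_sizes[0]" and records
#   SavedAttribute(nctype=NamedCType("self_sizes", BaseCType(intArrayRefT)),
#                  expr="self.sizes()")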
Example #8
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
    if t == BaseCType(tensorListT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
    # There are technically other non-owning types out there (like IntArrayRef),
    # but functionalization only actually cares about the ones involving tensors.
    return t, lambda x: x
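
# Usage sketch (an illustration, not from the source): functionalization
# codegen can use this to capture a TensorList argument by value:
#
#   owned, conv = get_owning_type(BaseCType(tensorListT))
#   # owned == VectorCType(BaseCType(tensorT))
#   # conv("args") == "args.vec()"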
Example #9
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

longVec_ctype = VectorCType(BaseCType(longT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions, drawing only on the in-scope bindings (`bindings`),
# that covers all of the types of the goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {