Example #1
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *,
             is_out: bool) -> List[Binding]:
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaults
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default).
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase the SelfArgument distinction; treat it as an ordinary argument
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = '{}'
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType('dtype',
                                  OptionalCType(BaseCType(scalarTypeT))),
                name='dtype',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('layout', OptionalCType(BaseCType(layoutT))),
                name='layout',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('device', OptionalCType(BaseCType(deviceT))),
                name='device',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('pin_memory',
                                  OptionalCType(BaseCType(boolT))),
                name='pin_memory',
                default=default,
                argument=a,
            )
        ]
    else:
        assert_never(a)
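A minimal usage sketch of the expansion above (toy stand-ins, not the real tools.codegen types): for non-out variants every TensorOptions component gets the '{}' default, while out variants get none.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBinding:  # simplified stand-in for tools.codegen's Binding
    name: str
    default: Optional[str]

def expand_tensor_options(*, is_out: bool) -> List[ToyBinding]:
    # Mirrors the TensorOptionsArguments branch: out variants get no
    # defaults; non-out variants default each component to '{}'.
    default = None if is_out else '{}'
    return [ToyBinding(n, default)
            for n in ('dtype', 'layout', 'device', 'pin_memory')]

assert [b.default for b in expand_tensor_options(is_out=False)] == ['{}'] * 4
assert [b.default for b in expand_tensor_options(is_out=True)] == [None] * 4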
Example #2
def process_ir_type(
        typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
     (1) changing at::Tensors into lazy::Values
     (2) wrapping everything in a BaseCType
     (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
    There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'

    This is incomplete- there are assertions in places that it's expected to need to add
    more types as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(valueT)
        elif typ.name == BaseTy.Scalar:
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(valueT)
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        elif str(typ.elem) == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
def emit_view_lambda(f: NativeFunction,
                     unpacked_bindings: List[Binding]) -> str:
    """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
    input_base = 'input_base'
    replay_view_func = ''
    updated_unpacked_args: List[str] = []
    known_view_arg_simple_types: List[CType] = [
        BaseCType(intT),
        OptionalCType(BaseCType(intT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT)
    ]
    for unpacked_binding in unpacked_bindings:
        arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
        if arg == 'self_':
            updated_unpacked_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ', '.join(
                [str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: '
                f'{known_types_str}. Please update the list or materialize it so that it can be closed '
                'over by value; also add a test in pytorch/xla/test/test_operations.py where this code '
                'is exercised.')

        if arg_type == BaseCType(intArrayRefT):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + '_vec'
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg,
                                                           vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(intT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + '_val'
            replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg,
                                                           val=arg_value,
                                                           default='0')
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base, replay_view_call=replay_view_call)

    is_view_with_metadata_change = 'true' if cpp.name(
        f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else 'false'

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func)
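The lambda body above is stitched together from string templates (ARRAYREF_TO_VEC, OPTIONAL_TO_VAL, REPLAY_VIEW_LAMBDA_FUNC). A rough sketch of the same technique using the standard library's string.Template; the real codegen uses its own CodeTemplate class, and the template text here is an assumption:

from string import Template

# Hypothetical stand-in for ARRAYREF_TO_VEC: materialize a non-owning
# IntArrayRef argument into an owning std::vector so the replay lambda
# can safely capture it by value.
ARRAYREF_TO_VEC = Template("auto ${vec} = ${arg}.vec();\n")

replay_view_func = ARRAYREF_TO_VEC.substitute(arg="size", vec="size_vec")
print(replay_view_func)  # auto size_vec = size.vec();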
Example #4
    def emit_fw_derivatives() -> List[str]:
        content: List[str] = []
        fw_grad_setters: List[str] = []
        for derivative in fw_derivatives:
            res = derivative.var_name
            if f.func.name.name.inplace:
                # TODO update this when inplace namings are unified
                res = "self"

            assert derivative.required_inputs_fw_grad is not None

            unpacked_arguments = ""
            for inp in differentiable_inputs:
                zeros_fn = "zeros" if inplace and inp.name == "self" else "_efficientzerotensor"
                if inp.name in derivative.required_inputs_fw_grad:
                    unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp=inp.name, zeros_fn=zeros_fn)
                if inp.name in (derivative.required_inputs_primal or []):
                    unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp=inp.name)
            if derivative.required_original_self_value:
                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp="original_self", zeros_fn=zeros_fn)
                unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp="original_self")
            elif inplace and derivative.is_reusing_outplace_formula:
                # The gradient wasn't already cloned; do it if grad mode is enabled
                unpacked_arguments += "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"

            if inplace:
                is_inplace_str = "true"
            else:
                is_inplace_str = "false"

            if isinstance(derivative.var_type, BaseType) and derivative.var_type.is_tensor_like():
                # Is there a way to get from BaseType to BaseCType?
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR.substitute(out_arg=res, is_inplace=is_inplace_str)
            elif isinstance(derivative.var_type, ListType) and derivative.var_type.is_tensor_like():
                opt_res_grad_type = OptionalCType(VectorCType(BaseCType(tensorT))).cpp_type()
                fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(out_arg=res, is_inplace=is_inplace_str)
            else:
                raise RuntimeError("Unsupported output type for forward derivative")

            fw_grad_opt_definition = f"{opt_res_grad_type} {res}_new_fw_grad_opt = c10::nullopt;"

            # View ops create a fw_grad that is already a view of the base's fw_grad, so just use that
            content.append(FW_DERIVATIVE_TEMPLATE.substitute(
                fw_grad_opt_definition=fw_grad_opt_definition,
                requires_fw_grad=get_any_has_forward_grad_name(derivative.var_name), formula=derivative.formula, out_arg=res,
                unpacked_arguments=unpacked_arguments))
            fw_grad_setters.append(fw_grad_setter)

        # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
        content.append('\n'.join(fw_grad_setters))
        return content
Example #5
def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.  Currently its output is used in several places, and so far
    it has been possible for them to all use the same conversions, but that may not be
    optimal or possible in the finished system.

    Type conversion for lazy currently consists of
     (1) changing Tensor-like things into Value-like things
     (2) wrapping everything in a BaseCType
     (3) making reference types into values (e.g. vector instead of IntArrayRef)

    (1) converts Tensors to Values since Values are how Lazy IR represents tensors.  There
    is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'

    This is incomplete- there are assertions in places that it's expected to need to add
    more types as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(valueT)
        elif typ.name == BaseTy.Scalar:
            # at::scalar has special handling,
            # and is wrapped in an IR value just like at::tensor
            return BaseCType(valueT)
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
Example #6
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    if str(t) == 'Tensor?':
        tensor_type: OptionalCType = OptionalCType(BaseCType('Tensor', binds))
        if mutable:
            return MutRefCType(tensor_type)
        else:
            return ConstRefCType(tensor_type)
    elif str(t) == 'Tensor?[]':
        return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
    elif str(t) == 'Scalar':
        return ConstRefCType(BaseCType('Scalar', binds))
    elif str(t) == 'Scalar?':
        return ConstRefCType(OptionalCType(BaseCType('Scalar', binds)))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Example #7
def save_variables(
    saved_variables: Sequence[SavedAttribute],
    is_output: bool,
    guard_for: Callable[[SavedAttribute],
                        Optional[str]] = lambda name: None,
) -> Sequence[str]:
    # assign the saved variables to the generated grad_fn
    stmts: List[str] = []
    for arg in saved_variables:
        name = arg.nctype.name.name if isinstance(
            arg.nctype.name, SpecialArgName) else arg.nctype.name
        type = arg.nctype.type
        expr = arg.expr
        stmts_prepend = None
        if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
                type == MutRefCType(OptionalCType(BaseCType(tensorT))) or (is_output and type == BaseCType(scalarT)):
            var = name
            name += '_'
            if var == 'self' and inplace:
                stmts_prepend = 'if (!original_self.has_value()) original_self = self.clone()'
                var = 'original_self.value()'
                assert not is_output
            if inplace and is_output:
                var = 'self'
                is_inplace_view = f'{var}.is_view()'
                expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
            else:
                expr = f'SavedVariable({var}, {str(is_output).lower()})'
        elif type == BaseCType(tensorListT) or type == ListCType(
                OptionalCType(BaseCType(tensorT))):
            expr = f'make_saved_variable_list({name})'
            name += '_'
        elif type == BaseCType(intArrayRefT):
            expr = expr + ".vec()"
        elif type == BaseCType(stringT):
            expr = f'std::string({expr})'
        elif type == OptionalCType(BaseCType(stringT)):
            expr = f'{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt'
        guard = guard_for(arg)
        if guard is None:
            if stmts_prepend:
                stmts.append(f'{stmts_prepend};')
            stmts.append(f'grad_fn->{name} = {expr};')
        else:
            stmts.append(f'if ({guard}) {{')
            if stmts_prepend:
                stmts.append(f'  {stmts_prepend};')
            stmts.append(f'  grad_fn->{name} = {expr};')
            stmts.append('}')
    return stmts
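The guard_for hook above either emits an assignment bare or wraps it in a C++ if-block. A minimal sketch of that emission pattern (toy names, not the real API):

from typing import List, Optional

def emit(assign: str, guard: Optional[str]) -> List[str]:
    # With no guard, the assignment is emitted as-is; with a guard it is
    # wrapped in an if-block, as save_variables does above.
    if guard is None:
        return [f'{assign};']
    return [f'if ({guard}) {{', f'  {assign};', '}']

assert emit('grad_fn->self_ = SavedVariable(self, false)', None) == \
    ['grad_fn->self_ = SavedVariable(self, false);']
assert emit('grad_fn->other_ = SavedVariable(other, false)',
            'grad_fn->should_compute_output(1)') == \
    ['if (grad_fn->should_compute_output(1)) {',
     '  grad_fn->other_ = SavedVariable(other, false);',
     '}']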
Example #8
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[CType]:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        elif t.name == BaseTy.int:
            return BaseCType('int64_t', binds)
        elif t.name == BaseTy.float:
            return BaseCType('double', binds)
        elif t.name == BaseTy.str:
            return BaseCType('std::string', binds)
        elif t.name in [
                BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, BaseTy.ScalarType,
                BaseTy.Generator, BaseTy.Storage, BaseTy.Layout, BaseTy.Device,
                BaseTy.MemoryFormat, BaseTy.Dimname, BaseTy.Stream,
                BaseTy.ConstQuantizerPtr
        ]:
            # These C++ names line up with their schema names
            return BaseCType(t.name.name, binds)
        else:
            raise AssertionError(f"unsupported type: {t}")
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return OptionalCType(elem)
    elif isinstance(t, ListType):
        if str(t.elem) == 'bool':
            assert t.size is not None
            return BaseCType(f"std::array<bool,{t.size}>", binds)
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #9
def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    save_ptrs_stmts: List[str] = []
    enforce_same_ptrs_stmts: List[str] = []
    if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                            ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                            ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
            elif noref_cpp_type == BaseCType(tensorT):
                save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
                                            ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
    assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
    if save_ptrs_stmts and enforce_same_ptrs_stmts:
        call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
            call + \
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
    return call
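The generated code forms a sandwich: save the storage/impl pointers, run the call, then check that nothing was reallocated, with both halves active only in debug builds. A toy sketch of the wrapping step (the #ifndef body is an assumption; the real guard is the RUN_ONLY_IN_DEBUG_MODE template):

from typing import List

def wrap_debug_checks(call: str, save: List[str], enforce: List[str]) -> str:
    # Bracket the call with debug-only blocks, mirroring the structure above.
    if not (save and enforce):
        return call
    def debug_only(stmts: List[str]) -> str:
        return '#ifndef NDEBUG\n' + '\n'.join(stmts) + '\n#endif\n'
    return debug_only(save) + call + debug_only(enforce)

print(wrap_debug_checks('auto out = op(self);\n',
                        ['auto self_storage = self.storage();'],
                        ['TORCH_INTERNAL_ASSERT(self_storage.is_alias_of(self.storage()));']))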
Example #10
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    if str(t) == 'Tensor?':
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        else:
            return NamedCType(binds, ConstRefCType(tensor_type))
    elif str(t) == 'Tensor?[]':
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
    elif str(t) == 'Scalar':
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif str(t) == 'Scalar?':
        return NamedCType(binds,
                          ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Example #11
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(
                    BaseCType(tensorT)))  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        elif str(t.elem) == 'Scalar':
            return NamedCType(binds,
                              ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Tensor':
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == 'Scalar':
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == 'Tensor?':
            return NamedCType(
                binds,
                ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #12
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    # If it's a value type, do the value type translation
    r = valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))
            else:
                return ConstRefCType(BaseCType('Tensor', binds))
        elif t.name == BaseTy.Scalar:
            return ConstRefCType(BaseCType('Scalar', binds))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return MutRefCType(BaseCType(
                    'Tensor', binds))  # TODO: fix this discrepancy
            else:
                return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
        elif str(t.elem) == 'Scalar':
            return ConstRefCType(OptionalCType(BaseCType('Scalar', binds)))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return OptionalCType(elem)
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, we'll need to add it back.
        if str(t.elem) == 'int':
            return BaseCType("IntArrayRef", binds)
        elif str(t.elem) == 'Tensor':
            return BaseCType("TensorList", binds)
        elif str(t.elem) == 'Scalar':
            return BaseCType("ArrayRef<Scalar>", binds)
        elif str(t.elem) == 'Dimname':
            return BaseCType("DimnameList", binds)
        elif str(t.elem) == 'Tensor?':
            return ConstRefCType(
                BaseCType("c10::List<c10::optional<Tensor>>", binds))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #13
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == 'bool':
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #14
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        elif t.elem == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #15
def valuetype_type(t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False) -> Optional[NamedCType]:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError("string ref->value conversion: not implemented yet")
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == 'bool':
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #16
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == 'int':
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False

        if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
                type == MutRefCType(OptionalCType(BaseCType(tensorT))) or \
                (type == BaseCType(scalarT) and is_output):
            saved_variables.append(f'SavedVariable {name}_;')
            release_variables.append(f'{name}_.reset_data();')
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append(f'auto {name} = {name}_.unpack({ptr});')
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR))
            should_append_raw_getsetdef = True
        elif type == BaseCType(tensorListT):
            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
            saved_variables.append(f'bool {name}_released_ = false;')
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f'{name}_.clear();')
            release_variables.append(f'{name}_released_ = true;')
            unpack.append(f'auto {name} = unpack_list({name}_);')
            asserts.append(
                f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
            should_append_raw_getsetdef = True
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
            saved_variables.append(f'bool {name}_released_ = false;')
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f'{name}_.clear();')
            release_variables.append(f'{name}_released_ = true;')
            unpack.append(f'auto {name} = unpack_opt_list({name}_);')
            asserts.append(
                f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
            should_append_raw_getsetdef = True
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f'std::vector<int64_t> {name};')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(op=info.op,
                                             name=name,
                                             body=GETTER_BODY_ARRAYREF_LONG))
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f'c10::OptionalArray<double> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
        elif type == BaseCType(intT):
            saved_variables.append(f'{type.cpp_type()} {name} = 0;')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(op=info.op,
                                             name=name,
                                             body=GETTER_BODY_INT64_T))
        elif type == BaseCType(stringT):
            saved_variables.append(f'std::string {name};')
            getter_definitions.append(
                GETTER_DEFINITION.substitute(op=info.op,
                                             name=name,
                                             body=GETTER_BODY_STRING))
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f'c10::optional<std::string> {name};')
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(op=info.op,
                                                 name=name,
                                                 body=GETTER_BODY_STRING))
        else:
            saved_variables.append(f'{type.cpp_type()} {name};')

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body))
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
        if should_append_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
Example #18
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(intT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
Example #19
def translate(bindings: Sequence[Union[Expr, Binding]],
              goals: Sequence[Union[NamedCType, Binding]],
              *,
              method: bool = False) -> List[Expr]:

    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(Expr(
                expr=b.name,
                type=b.nctype,
            ))
        else:
            binding_exprs.append(b)

    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial
        t = b.type
        if isinstance(t.type, ConstRefCType) and isinstance(t.type.elem, OptionalCType) and \
                isinstance(t.type.elem.elem, BaseCType) and str(t.type.elem.elem.type) == 'at::Tensor':
            ctx[NamedCType(t.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \
                f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(
            BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(
            BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # This would be better, but is disabled for byte-for-byte compatibility:
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        ctx_desc = '\n'.join(f"  {t.cpp_type()} {t.name}; // {e}"
                             for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A shitty backtracking search implementation.  It's shitty because it
    # doesn't actually do backtracking or search. In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name,
                                        MutRefCType(goal.type.elem)),
                             direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem),
                             direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format",
                              OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format,
                           OptionalCType(BaseCType(memoryFormatT))))
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format

        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype",
                                OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
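A compressed, runnable sketch of the solve strategy: try the context first, then (unless direct=True) fall back to a synthesis rule. Everything here is a toy over plain strings; the real solver works over NamedCType goals:

from typing import Dict

class Unsat(RuntimeError):
    pass

ctx: Dict[str, str] = {'options': 'options'}  # in-scope bindings: goal -> expr

def solve(goal: str, *, direct: bool = False) -> str:
    if goal in ctx:      # trivial: the goal is already in scope
        return ctx[goal]
    if direct:           # direct mode: no fancy synthesis allowed
        raise Unsat(goal)
    if goal == 'dtype':  # Gather: extract "dtype" from the wider "options"
        options = solve('options', direct=True)
        return f'optTypeMetaToScalarType({options}.dtype_opt())'
    raise Unsat(goal)

assert solve('dtype') == 'optTypeMetaToScalarType(options.dtype_opt())'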
Example #20
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format

        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)
Example #21
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)):
    (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
Example #22
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

longVec_ctype = VectorCType(BaseCType(longT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))

class UnsatError(RuntimeError):
    pass

# Given a set of in-scope bindings and a set of target goals, synthesize a
# list of expressions that uses only the in-scope bindings and covers all of
# the goal types.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
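For a concrete feel of the Scatter case described above: when the goal is the aggregate "options" and only its pieces are in scope, the solver glues them back together into one TensorOptions expression. A toy version of the "options" rule seen in solve() in Examples #19 and #20:

def scatter_options(dtype: str, layout: str, device: str, pin_memory: str) -> str:
    # Rebuild a TensorOptions expression from its four scattered components.
    return (f'TensorOptions().dtype({dtype}).layout({layout})'
            f'.device({device}).pinned_memory({pin_memory})')

print(scatter_options('dtype', 'layout', 'device', 'pin_memory'))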
Example #23
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(intT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
Example #24
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name, ), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.')
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (r'{}.sizes\(\)', {
            'suffix': '_sizes',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
        }),
        # replace self.options() with self_options
        (r'{}.options\(\)', {
            'suffix': '_options',
            'nctype': lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
        }),
        # replace zeros_like(self) with self_info
        (
            r'zeros_like\({}\)',
            {
                'suffix': '_info',
                'nctype':
                lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                'expr': lambda name: name,  # at save-time
                'res': lambda name: name + '_info.zeros()',  # at eval-time
            }),
        # replace self.size(2) with self_size_2
        (r'{}.size\((\w+)\)', {
            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.numel() with self_numel
        (r'{}.numel\(\)', {
            'suffix': '_numel',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace to_args_sizes(self) with self_args_sizes
        (r'to_args_sizes\({}\)', {
            'suffix':
            '_args_sizes',
            'nctype':
            lambda name: NamedCType(name,
                                    VectorCType(VectorCType(BaseCType(intT)))),
        }),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (r'to_args_scalartypes\({}\)', {
            'suffix':
            '_args_scalartypes',
            'nctype':
            lambda name: NamedCType(name, VectorCType(BaseCType(scalarTypeT))),
        }),
        # replace TensorGeometry(self) with self_geometry
        (r'TensorGeometry\({}\)', {
            'suffix': '_geometry',
            'nctype':
            lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
        }),
        (r'{}.scalar_type\(\)', {
            'suffix': '_scalar_type',
            'nctype': lambda name: NamedCType(name, BaseCType(scalarTypeT)),
        }),
        # replace self.dim() with self_dim
        (r'{}.dim\(\)', {
            'suffix': '_dim',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.strides() with self_strides
        (r'{}.strides\(\)', {
            'suffix': '_strides',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            'expr': stride_expr,
        }),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = nctype.name.name if isinstance(nctype.name,
                                              SpecialArgName) else nctype.name
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = info['suffix'](m) if callable(
                    info['suffix']) else info['suffix']
                expr: str = info['expr'](name) if 'expr' in info else m.group(
                    0)
                saved.append(
                    SavedAttribute(
                        nctype=info['nctype'](name + suffix),
                        expr=expr,
                    ))
                if 'res' in info:
                    replacement: str = info['res'](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf'\b{name}\b',
                f'{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt',
                formula)

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(SavedAttribute(
                nctype=nctype,
                expr=name,
            ))

    return formula, tuple(saved)
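The core mechanic above is re.sub with a callback: each match simultaneously rewrites the formula and records, as a side effect, what must be saved. A self-contained sketch for the ".sizes()" replacement (toy record type, not the real SavedAttribute):

import re
from typing import List, Tuple

def rewrite_sizes(formula: str, name: str) -> Tuple[str, List[str]]:
    saved: List[str] = []

    def repl(m: 're.Match[str]') -> str:
        saved.append(name + '_sizes')  # record the attribute to save ...
        return name + '_sizes'         # ... and rewrite the call site

    # replace e.g. "self.sizes()" with "self_sizes", as in REPLACEMENTS above
    formula = re.sub(rf'{re.escape(name)}\.sizes\(\)', repl, formula)
    return formula, saved

assert rewrite_sizes('grad.reshape(self.sizes())', 'self') == \
    ('grad.reshape(self_sizes)', ['self_sizes'])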
Example #25
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))


        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it behind an explicit kwarg because converting to a value
        # type can be expensive (e.g., O(n) to convert an IntArrayRef to a
        # vector<int>), so the caller of translate() should be explicit that
        # they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # with arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        unsat(goal)
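
The two reference-relaxation rules at the top of solve form a chain: a const& goal may be served by a mutable& binding, which in turn may be served by a value binding, with UnsatError driving the backtracking. A toy model of just that chain (ToyType and the string-keyed context are hypothetical, not the codegen's types):

from dataclasses import dataclass
from typing import Dict

class UnsatError(RuntimeError):
    pass

@dataclass(frozen=True)
class ToyType:
    base: str
    ref: str = 'value'  # one of: 'value', 'mut_ref', 'const_ref'

def toy_solve(goal: ToyType, ctx: Dict[ToyType, str]) -> str:
    if goal in ctx:
        return ctx[goal]  # trivial: an exact binding exists
    if goal.ref == 'const_ref':  # const& is satisfied with mutable&
        try:
            return toy_solve(ToyType(goal.base, 'mut_ref'), ctx)
        except UnsatError:
            pass
    if goal.ref == 'mut_ref':  # mutable& is satisfied with a value
        try:
            return toy_solve(ToyType(goal.base, 'value'), ctx)
        except UnsatError:
            pass
    raise UnsatError(f'cannot synthesize {goal}')

# toy_solve(ToyType('Tensor', 'const_ref'),
#           {ToyType('Tensor'): 'self_'})  # -> 'self_'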
Example #26
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

longVec_ctype = VectorCType(BaseCType(longT))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings and provides an
# expression of each goal type.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
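
A usage sketch of translate() in the "Gather" direction described above: extracting an optional dtype from a TensorOptions binding. The import paths are assumptions (this module has lived under both tools.codegen and torchgen across PyTorch versions), so treat this as illustrative rather than copy-pasteable.

from tools.codegen.api.translate import translate          # path may vary
from tools.codegen.api.types import (
    BaseCType, ConstRefCType, Expr, NamedCType, OptionalCType,
    scalarTypeT, tensorOptionsT,
)

# One in-scope binding: `const TensorOptions& options`.
options = Expr(
    expr='options',
    type=NamedCType('options', ConstRefCType(BaseCType(tensorOptionsT))),
)
# One goal: an optional dtype.
goal = NamedCType('dtype', OptionalCType(BaseCType(scalarTypeT)))

[e] = translate([options], [goal])
# e.expr == 'optTypeMetaToScalarType(options.dtype_opt())'  (Gather)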
Example #27
    def check_tensorimpl_and_storage(call: str,
                                     unpacked_bindings: List[Binding]) -> str:
        # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
        stmts_before_call: List[str] = []
        stmts_after_call: List[str] = []

        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            return call

        # Check properties of inputs (enforce (1))
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                stmts_before_call += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(
                    BaseCType(tensorT))):
                stmts_before_call += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                stmts_before_call += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(
                        tensor_name=arg, out_tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]

        assert (stmts_before_call and stmts_after_call) or \
            (not stmts_before_call and not stmts_after_call)

        # Check properties of outputs (enforce (2), (3))
        if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
            if aliased_arg_name is not None:
                aliased_arg_name = unpacked_name(aliased_arg_name)
            for i, (ret, ret_name) in enumerate(
                    zip(f.func.returns, cpp.return_names(f))):
                noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                if noref_cpp_type == BaseCType(tensorT):
                    if aliased_arg_name is not None:
                        assert i == 0, f"Expected non-CompositeImplicitAutograd view function {base_name} to return a single output"
                        stmts_after_call += [
                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
                                tensor_name=aliased_arg_name,
                                out_tensor_name=ret_name)
                        ]
                    else:
                        if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                    tensor_name=ret_name, fn_name=type_wrapper_name(f))
                            ]

                    if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]

                # Currently we don't have any functions that return the following types, but
                # we should update the checks once we do
                elif noref_cpp_type == ListCType(
                        OptionalCType(BaseCType(tensorT))):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")
                elif noref_cpp_type == BaseCType(tensorListT):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")

        if stmts_before_call and stmts_after_call:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
        return call
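
The net effect of check_tensorimpl_and_storage is a debug-only sandwich around the call: pointers are saved before it runs and re-checked after. A standalone sketch, with a plain format string standing in for the real RUN_ONLY_IN_DEBUG_MODE CodeTemplate (the template text and the wrap_call helper are illustrative assumptions):

from typing import List

# Illustrative stand-in for the real CodeTemplate; the actual guard used by
# the codegen may differ.
DEBUG_ONLY = '#ifndef NDEBUG\n{statements}\n#endif\n'

def wrap_call(call: str, before: List[str], after: List[str]) -> str:
    if not (before and after):
        return call  # nothing to check; emit the call as-is
    return (DEBUG_ONLY.format(statements='\n'.join(before))
            + call + '\n'
            + DEBUG_ONLY.format(statements='\n'.join(after)))

# wrap_call('auto result = add(self_, other_);',
#           ['auto self__storage = self_.storage();'],
#           ['AT_ASSERT(self_.storage().is_alias_of(self__storage));'])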
Example #28
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *, method: bool = False,
    allow_expensive_conversions: bool = False
) -> List[Expr]:

    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(Expr(
                expr=b.name,
                type=b.nctype,
            ))
        else:
            binding_exprs.append(b)

    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # These lecture notes are a good starting point in the literature for
        # exploring more about proof search:
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as the compiler is often unwilling to CSE
        # nontrivial expressions like scalar.to<scalar_t>().
        t = b.type
        # const c10::optional<Tensor>& can be unwrapped into const Tensor&,
        # with an undefined tensor standing in for nullopt
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'({b.expr}).to<opmath_t>()'

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \
                f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())'

        if t.type == BaseCType(scalar_t):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'static_cast<opmath_t>({b.expr})'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # Using "*this" directly would be better, but the const_cast keeps the
        # generated code byte-for-byte compatible with the old codegen:
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        ctx_desc = '\n'.join(f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A crude backtracking search implementation.  It is crude because it
    # backtracks via exceptions on the call stack (a bad idea!) and for the
    # most part tries to avoid backtracking at all.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not to add a
                # direct conversion that satisfies mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format, OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))


        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it behind an explicit kwarg because converting to a value
        # type can be expensive (e.g., O(n) to convert an IntArrayRef to a
        # vector<int>), so the caller of translate() should be explicit that
        # they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # with arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
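
To make the forward-inference step concrete: a single binding of type const c10::optional<Tensor>& named weight seeds ctx with derived entries, so later goals can be solved trivially. The dictionary below is written out by hand to mirror the rules above (it is an illustration, not code the module contains):

ctx_for_optional_tensor_binding = {
    # the binding itself
    'const c10::optional<Tensor>& weight': 'weight',
    # unwrapped to const Tensor&, with an undefined tensor standing in
    # for nullopt
    'const Tensor& weight':
        '(weight.has_value() ? *weight : at::Tensor())',
    # repackaged as at::OptionalTensorRef
    'at::OptionalTensorRef weight':
        '((weight.has_value() && (*weight).defined())'
        ' ? at::OptionalTensorRef(*weight) : at::OptionalTensorRef())',
}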