def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> List[Expr]:
    """For each goal type, synthesize a C++ expression using only the in-scope
    bindings.  Raises UnsatError when a goal cannot be satisfied.

    method: if True, add implicit ``self``/``*this`` bindings (generated code
        lives inside a Tensor method).
    allow_expensive_conversions: if True, also permit O(n) reference-to-value
        conversions (e.g. IntArrayRef -> std::vector<int64_t>).
    """
    # Normalize bindings into Exprs (a Binding is the trivial expression that
    # just names itself).
    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(Expr(
                expr=b.name,
                type=b.nctype,
            ))
        else:
            binding_exprs.append(b)

    # Normalize goals into NamedCTypes.
    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about
        # proof search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is
        # nontrivial.  Fix this by implementing some sort of sharing so that
        # if multiple goals share the same expression, we only compute it
        # once.  This seems to matter in practice as compiler is often
        # unwilling to CSE nontrivial expressions like scalar.to<scalar_t>()
        t = b.type
        # NOTE(review): t is a NamedCType, so this isinstance test against a
        # CType subclass looks like it can never fire -- possibly dead code
        # left over from before the NamedCType refactor; confirm and remove.
        if isinstance(t, ConstRefCType) and isinstance(t.elem, OptionalCType) and \
                isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == 'at::Tensor':
            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = \
                f'(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())'
        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'({b.expr}).to<opmath_t>()'
        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = \
                f'({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())'
        if t.type == BaseCType(scalar_t):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f'static_cast<opmath_t>({b.expr})'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # This is better! Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        # Render the available context into the error message so a codegen
        # author can see what was in scope when synthesis failed.
        ctx_desc = '\n'.join(f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A shitty backtracking search implementation.  It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if direct=True, we won't try to do
    # any fancy synthesis, just trivial conversions (e.g., "T a" is OK for
    # "const T& a").  So all of the existing rules in this function simply try
    # to solve immediately, and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that goes satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format,
                           OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API
            # takes "options" directly.  Otherwise it will cause the redundant
            # memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            # Scatter: reassemble TensorOptions from its four components.
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            # Gather: pull an individual component back out of TensorOptions.
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'
        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'
        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'
        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types,
        # like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a
        # reference type, and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and
        # we need to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value
        # type is expensive (O(n)) to convert from IntArrayRef to
        # vector<int>), so the caller of translate() should be explicit that
        # they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # BUGFIX: this branch previously repeated the
            # OptionalCType(BaseCType(scalarT)) test of the branch above,
            # making it unreachable.  It synthesizes the optional *Tensor*
            # counterpart (via OptionalTensorRef), so it must test tensorT.
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'

        # Technically, we also need to handle cases of C++ containers holding
        # reference types.  But there currently aren't any ops that require
        # lambda capture codegen with arguments like std::vector<IntArrayRef>.
        # If that changes, we'll have to add the translation here.

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    """Rewrite a derivative formula so that subexpressions which can be
    evaluated at save time (e.g. ``self.sizes()`` -> ``self_sizes``) are
    replaced by named saved attributes, and collect those attributes.

    Returns the rewritten formula and the tuple of SavedAttributes to stash
    on the autograd Function.
    """
    def stride_expr(name: str) -> str:
        # Strides are only recoverable for the single tensor being
        # differentiated, hence the restriction on var_names.
        assert var_names == (name, ), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.')
        return f'strides_or_error({name}, "{name}")'

    # Each entry: (regex template over the variable name, replacement info).
    # Info keys:
    #   'suffix' -- string (or callable on the Match) appended to the variable
    #               name to form the saved-attribute name
    #   'nctype' -- callable mapping the attribute name to its NamedCType
    #   'expr'   -- optional save-time expression (defaults to the matched text)
    #   'res'    -- optional eval-time replacement (defaults to name + suffix)
    # NOTE(review): the leading '.' in templates like r'{}.sizes\(\)' is an
    # unescaped regex dot, so it matches any character -- presumably benign
    # because formulas always use a literal '.', but confirm before relying
    # on it.
    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (r'{}.sizes\(\)', {
            'suffix': '_sizes',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
        }),
        # replace self.options() with self_options
        (r'{}.options\(\)', {
            'suffix': '_options',
            'nctype': lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
        }),
        # replace zeros_like(self) with self_info
        (r'zeros_like\({}\)', {
            'suffix': '_info',
            'nctype': lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
            'expr': lambda name: name,  # at save-time
            'res': lambda name: name + '_info.zeros()',  # at eval-time
        }),
        # replace self.size(2) with self_size_2
        (r'{}.size\((\w+)\)', {
            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.numel() with self_numel
        (r'{}.numel\(\)', {
            'suffix': '_numel',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace to_args_sizes(self) with self_args_sizes
        (r'to_args_sizes\({}\)', {
            'suffix': '_args_sizes',
            'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(intT)))),
        }),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (r'to_args_scalartypes\({}\)', {
            'suffix': '_args_scalartypes',
            'nctype': lambda name: NamedCType(name, VectorCType(BaseCType(scalarTypeT))),
        }),
        # replace TensorGeometry(self) with self_geometry
        (r'TensorGeometry\({}\)', {
            'suffix': '_geometry',
            'nctype': lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
        }),
        # replace self.scalar_type() with self_scalar_type
        (r'{}.scalar_type\(\)', {
            'suffix': '_scalar_type',
            'nctype': lambda name: NamedCType(name, BaseCType(scalarTypeT)),
        }),
        # replace self.dim() with self_dim
        (r'{}.dim\(\)', {
            'suffix': '_dim',
            'nctype': lambda name: NamedCType(name, BaseCType(intT)),
        }),
        # replace self.strides() with self_strides
        (r'{}.strides\(\)', {
            'suffix': '_strides',
            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            'expr': stride_expr,
        }),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:
            # repl closes over the loop variables `info` and `name`; this is
            # safe because re.sub invokes it synchronously within the same
            # iteration (it also appends to `saved` as a side effect).
            def repl(m: Match[str]) -> str:
                suffix: str = info['suffix'](m) if callable(info['suffix']) else info['suffix']
                expr: str = info['expr'](name) if 'expr' in info else m.group(0)
                saved.append(SavedAttribute(
                    nctype=info['nctype'](name + suffix),
                    expr=expr,
                ))
                if 'res' in info:
                    replacement: str = info['res'](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(SavedAttribute(
                nctype=nctype,
                expr=name,
            ))

    return formula, tuple(saved)
# other scope); others are more nontrivial and may require packing/unpacking. # Some examples of non-trivial action: # # - Need the "dtype" binding? Well, maybe "dtype" isn't available # in the context, instead, "options" is, and you need to extract # it from there. (Gather) # # - Need the "context" binding? Well, maybe "context" isn't available # in the context, and you need to construct it from "dtype", "device", # etc. (Scatter) # # - Need the "memory_format" binding? Well, actually, it's available # from both "memory_format" and "options", so you had better make sure # they are consistent. (Join) options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) longVec_ctype = VectorCType(BaseCType(longT)) optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) class UnsatError(RuntimeError): pass # Given a set of in-scope bindings and a set of target bindings, synthesize # a list of expressions that uses only the in-scope bindings (bindings) that # have all of the types of goals. You may want to use this function if # you're generating code for a function like: # # void f({args}) {
    # NOTE(review): this is a closure fragment -- it reads `ctx`, `unsat`,
    # `goal_ctypes`, `options_ctype`, the *_ctype constants and
    # `allow_expensive_conversions` from an enclosing translate(); it is not
    # a standalone top-level function.  Presumably a duplicated chunk of the
    # same file -- confirm against the surrounding context.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that goes satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # In direct mode only the trivial conversions above are allowed.
        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format,
                           OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API
            # takes "options" directly.  Otherwise it will cause the redundant
            # memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        # Scatter: rebuild TensorOptions from its components.
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'
        # Gather: extract an individual component from TensorOptions.
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'
        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'
        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'
        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        # We can always do translations from value types to reference types,
        # like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a
        # reference type, and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture
        # and we need to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value
        # type is expensive (O(n)) to convert from IntArrayRef to
        # vector<int>), so the caller of translate() should be explicit that
        # they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f'{argname}.vec()'
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt'
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(goal.name, BaseCType(optionalScalarRefT))
                argname = direct_solve(optionalScalarRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'
            # NOTE(review): this condition repeats the scalarT test above, so
            # this branch is unreachable; given it builds an OptionalTensorRef
            # it presumably should test OptionalCType(BaseCType(tensorT)) --
            # confirm and fix.
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalTensorRef_ctype = NamedCType(goal.name, BaseCType(optionalTensorRefT))
                argname = direct_solve(optionalTensorRef_ctype)
                return f'{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt'

        # Technically, we also need to handle cases of C++ containers holding
        # reference types.  But there currently aren't any ops that require
        # lambda capture codegen
        # With arguments like std::vector<IntArrayRef>.
        # If that changes, we'll have to add the translation here.

        unsat(goal)
# - The lambda capture has to convert reference types to value types # - While the forward lambda just directly calls into the at::_ops API # (following the dispatcher convention), the logic here for the reverse lambda # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). # The lambdas generated for each view op in the functionalization pass are of the form # [capture_arguments](outer_arguments) -> returns_type { # return name(inner_arguments); # } # Define some specific lambda input arguments. base_binding = Binding(name='base', nctype=NamedCType(name='base', type=ConstRefCType( BaseCType(tensorT))), argument=Argument(name='base', type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None) mutated_view_binding = Binding(name='mutated_view', nctype=NamedCType(name='mutated_view', type=ConstRefCType( BaseCType(tensorT))), argument=Argument(name='base', type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None) mutated_view_idx_binding = Binding(name='mutated_view_idx',
} else if (prop.isIntegral(/*includeBool=*/false)) { return PyLong_FromLong(prop.to<int64_t>()); } else if (prop.isBoolean()) { if (prop.to<bool>()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } else { PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); return nullptr; } """

# Map from the C++ type of a saved attribute to the (getter definition,
# getter body) template pair used to emit its Python-binding getter.
# NOTE(review): the opening of the GETTER_BODY_* string above lies outside
# this chunk; the template text itself is reproduced verbatim.
MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(intT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
    # NOTE(review): closes over `f` (the NativeFunction being generated) plus
    # the SAVE_*/ENFORCE_* code templates and helper functions from the
    # enclosing scope -- not a standalone function.
    def check_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
        """Wrap `call` with debug-mode statements that check TensorImpl and
        Storage pointer invariants on the inputs (before the call) and the
        outputs (after the call).

        See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
        """
        stmts_before_call: List[str] = []
        stmts_after_call: List[str] = []

        # Some functions are explicitly exempted from these checks.
        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            return call

        # Check properties of inputs (enforce (1))
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                stmts_before_call += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                stmts_before_call += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                stmts_before_call += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]

        # Either both lists were populated or neither was.
        assert (stmts_before_call and stmts_after_call) or (not stmts_before_call and not stmts_after_call)

        # Check properties of outputs (enforce (2), (3))
        if not f.func.kind() in (SchemaKind.inplace, SchemaKind.out):
            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
            if aliased_arg_name is not None:
                aliased_arg_name = unpacked_name(aliased_arg_name)
            for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
                noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                if noref_cpp_type == BaseCType(tensorT):
                    if aliased_arg_name is not None:
                        # NOTE(review): "{base}" in this message is not
                        # interpolated (plain string, no f-prefix) -- it will
                        # be printed literally if the assert fires.
                        assert i == 0, "Expect non-CompositeImplicitAutograd view function {base} to return single output"
                        stmts_after_call += [
                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
                                tensor_name=aliased_arg_name, out_tensor_name=ret_name)
                        ]
                    else:
                        if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                    tensor_name=ret_name, fn_name=type_wrapper_name(f))
                            ]
                        if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                                    tensor_name=ret_name, fn_name=type_wrapper_name(f))
                            ]
                # Currently we don't have any functions that return the
                # following types, but we should update the checks once we do
                elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                    raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
                elif noref_cpp_type == BaseCType(tensorListT):
                    raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")

        # Only emit the checks in debug builds.
        if stmts_before_call and stmts_after_call:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)

        return call
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT argument type into the C++ API's NamedCType.

    Value types are delegated to valuetype_type; everything else (Tensor,
    Scalar, and optionals/lists thereof) is reference-typed and handled here.
    """
    # Value types all translate uniformly; try that first.
    value_translation = valuetype_type(t, binds=binds)
    if value_translation is not None:
        return value_translation

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            # Mutable tensors are passed by mutable ref, otherwise const ref.
            ref = MutRefCType if mutable else ConstRefCType
            return NamedCType(binds, ref(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        elem_str = str(t.elem)
        if elem_str == 'Tensor':
            if mutable:
                # TODO: fix this discrepancy
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        if elem_str == 'Scalar':
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        special = {
            'int': BaseCType(intArrayRefT),
            'Tensor': BaseCType(tensorListT),
            'Scalar': ArrayRefCType(BaseCType(scalarT)),
            'Dimname': BaseCType(dimnameListT),
            'Tensor?': ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))),
        }.get(str(t.elem))
        if special is not None:
            return NamedCType(binds, special)
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
    # NOTE(review): closure fragment -- reads `ctx`, `unsat`, `goal_ctypes`
    # and `options_ctype` from an enclosing translate(); a shorter variant of
    # the solver elsewhere in this source (no value/reference conversions).
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that goes satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # In direct mode, only the trivial conversions above are allowed.
        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format,
                           OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API
            # takes "options" directly.  Otherwise it will cause the redundant
            # memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        # Scatter: rebuild TensorOptions from its components.
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'
        # Gather: extract an individual component from TensorOptions.
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'
        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'
        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'
        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *,
    method: bool = False
) -> List[Expr]:
    """For each goal type, synthesize a C++ expression using only the in-scope
    bindings.  Raises UnsatError when a goal cannot be satisfied.

    (Earlier variant of translate() -- no allow_expensive_conversions and a
    smaller rule set than the sibling version in this source.)
    """
    # Normalize bindings into Exprs (a Binding is the trivial expression that
    # just names itself).
    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(Expr(
                expr=b.name,
                type=b.nctype,
            ))
        else:
            binding_exprs.append(b)

    # Normalize goals into NamedCTypes.
    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        # TODO: This could get us in recomputation trouble if b.expr is
        # nontrivial
        t = b.type
        # NOTE(review): this isinstance test against a CType subclass assumes
        # t carries the name on the inner BaseCType (t.elem.elem.name) --
        # consistent with a pre-NamedCType API; confirm it still fires after
        # the refactor.
        if isinstance(t, ConstRefCType) and isinstance(t.elem, OptionalCType) and \
                isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == 'at::Tensor':
            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = \
                f'({b.expr}.has_value() ? *{b.expr} : at::Tensor())'

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "const_cast<Tensor&>(*this)"
        # This is better! Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        # Render the available context into the error message so a codegen
        # author can see what was in scope when synthesis failed.
        ctx_desc = '\n'.join(f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items())
        raise UnsatError(f'''
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of tools.codegen.api.translate.
Check this module for more information.
''')

    # A shitty backtracking search implementation.  It's shitty because it
    # doesn't actually do backtracing or search.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that goes satisfies
                # mutable& with const&
                return solve(NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct)
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(SpecialArgName.possibly_redundant_memory_format,
                           OptionalCType(BaseCType(memoryFormatT)))
            )
            # No need to join "memory_format" and "options" if the target API
            # takes "options" directly.  Otherwise it will cause the redundant
            # memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        # Scatter: rebuild TensorOptions from its components.
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
            pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
            device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
            layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
            return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})'
        # Gather: extract an individual component from TensorOptions.
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            options = direct_solve(options_ctype)
            return f'optTypeMetaToScalarType({options}.dtype_opt())'
        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            options = direct_solve(options_ctype)
            return f'{options}.layout_opt()'
        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            options = direct_solve(options_ctype)
            return f'{options}.device_opt()'
        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            options = direct_solve(options_ctype)
            return f'{options}.pinned_memory_opt()'

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
# other scope); others are more nontrivial and may require packing/unpacking. # Some examples of non-trivial action: # # - Need the "dtype" binding? Well, maybe "dtype" isn't available # in the context, instead, "options" is, and you need to extract # it from there. (Gather) # # - Need the "context" binding? Well, maybe "context" isn't available # in the context, and you need to construct it from "dtype", "device", # etc. (Scatter) # # - Need the "memory_format" binding? Well, actually, it's available # from both "memory_format" and "options", so you had better make sure # they are consistent. (Join) options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) class UnsatError(RuntimeError): pass # Given a set of in-scope bindings and a set of target bindings, synthesize # a list of expressions that uses only the in-scope bindings (bindings) that # have all of the types of goals. You may want to use this function if # you're generating code for a function like: # # void f({args}) { # g({exprs}); // g is a different API # } # # and you need to generate "exprs". #
# other scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

# Canonical NamedCType for the TensorOptions argument.
# FIX: previously constructed as ConstRefCType(BaseCType("TensorOptions",
# "options")), which disagrees with every other copy of this constant in this
# source: BaseCType takes a single C++ type constant, and the argument name
# lives on the NamedCType wrapper.  The solver compares goals against
# NamedCType values (goal == NamedCType(...)) and passes options_ctype to
# direct_solve(), so it must be a NamedCType.
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))


class UnsatError(RuntimeError):
    # Raised when the solver cannot synthesize an expression for a goal.
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
def process_ir_type(
        typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """Convert a type from NativeFunctions into its lazy-tensor codegen CType.

    The lazy conversion currently consists of
     (1) changing at::Tensors into lazy::Values (which wrap lazy::Nodes, with
         which Lazy IR represents tensors), with special handling for
         tensor-like containers such as Optional[Tensor] and List[Tensor]
     (2) wrapping everything in a BaseCType
     (3) making cpp-reference types into cpp-value types (e.g. vector instead
         of IntArrayRef)

    This mapping is deliberately incomplete: unhandled types raise an
    AssertionError, and support is expected to grow as the codegen is used
    with more operators.
    """
    if isinstance(typ, BaseType):
        # at::Scalar has special handling and, just like at::Tensor, ends up
        # wrapped in a lazy::Value.
        scalar_like = {
            BaseTy.Tensor: valueT,
            BaseTy.Scalar: valueT,
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if typ.name not in scalar_like:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
        return BaseCType(scalar_like[typ.name])
    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    if isinstance(typ, ListType):
        elem_repr = str(typ.elem)
        if elem_repr == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        if elem_repr == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        return VectorCType(process_ir_type(typ.elem))
    raise AssertionError(f"unrecognized type {repr(typ)}")
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

# NOTE(review): the C++ text above is the tail of a triple-quoted template
# string whose opening delimiter lies outside this chunk — do not insert
# Python code or comments before the closing quotes.

# Saved-attribute CTypes that get a stock Python getter: maps each CType to
# the (getter definition template, getter body template) pair used to emit
# its binding.
MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
def gen_one(self, f: NativeFunction) -> Optional[str]:
    """Generate the code for one NativeFunction of this structured operator
    group for the current codegen target (namespaced declaration, namespaced
    definition, anonymous wrapper definition, or dispatcher registration).

    Returns None when nothing should be emitted for this function/target
    combination.
    """
    assert not f.manual_kernel_registration

    # Respect selective build: skip registration of unselected ops.
    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # TODO: Now, there is something interesting going on here.  In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation.  But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor.  How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other.  We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG!  So it's not implemented yet.
    if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result
    elif self.target is Target.NAMESPACED_DEFINITION:
        # The namespaced function simply forwards to the wrapper, translating
        # the cpp-API arguments into the native-signature arguments.
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result
    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        # CUDA kernels verify that all tensor arguments live on the right
        # device before running.
        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
            device_check_args = itertools.chain(
                f.func.arguments.out,
                f.func.arguments.flat_positional
            )
            sig_body.append(RegisterDispatchKey.gen_device_check(f.device_check, list(device_check_args), sig.name()))

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ', '.join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(
            e.expr for e in translate(
                context,
                structured.meta_arguments(self.g),
                method=False
            )
        )
        sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(Expr(
                expr=f"op.outputs_[{i}]",
                # TODO: Stop hardcoding that the output type is a Tensor.  Note
                # that for the codegen here this is fine because outputs_ is
                # hardcoded to be tensor already
                type=NamedCType(out_arg.nctype.name, MutRefCType(BaseCType(tensorT)))
            ))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ', '.join(
                e.expr for e in translate(
                    context,
                    out_sig.arguments(),
                    method=False
                )
            )
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set.  I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(
                e.expr for e in translate(
                    context,
                    structured.impl_arguments(self.g),
                    method=False
                )
            )
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            else:
                moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ', '.join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
f, k, class_name=class_name, parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""
    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def save_var(var: SavedAttribute, is_output: bool) -> None:
    """Emit the code fragments for one saved attribute of an autograd Node.

    Appends, as a side effect, to accumulator lists from the enclosing scope
    (saved_variables, release_variables, unpack, asserts, getter_definitions,
    py_getsetdef_structs — not visible in this chunk): the member declaration,
    release/unpack statements, and the Python getter bindings appropriate for
    the attribute's CType.
    """
    name = var.nctype.name
    # NOTE: 'type' shadows the builtin here; kept as-is.
    type = var.nctype.type
    should_append_getsetdef = True
    should_append_raw_getsetdef = False

    # Tensor-like attributes (and Scalar outputs) are stored as SavedVariable
    # so the autograd graph can release/repack them.
    if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
            type == MutRefCType(OptionalCType(BaseCType(tensorT))) or \
            (type == BaseCType(scalarT) and is_output):
        saved_variables.append(f'SavedVariable {name}_;')
        release_variables.append(f'{name}_.reset_data();')
        ptr = 'shared_from_this()' if is_output else ''
        unpack.append(f'auto {name} = {name}_.unpack({ptr});')
        getter_definitions.append(
            GETTER_DEFINITION_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
        getter_definitions.append(
            GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == BaseCType(tensorListT):
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
        saved_variables.append(f'bool {name}_released_ = false;')
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
        release_variables.append(f'{name}_.clear();')
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_list({name}_);')
        asserts.append(
            f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
        getter_definitions.append(
            GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        getter_definitions.append(
            GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == ListCType(OptionalCType(BaseCType(tensorT))):
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
        saved_variables.append(f'bool {name}_released_ = false;')
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
        release_variables.append(f'{name}_.clear();')
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_opt_list({name}_);')
        asserts.append(
            f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
        getter_definitions.append(
            GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        getter_definitions.append(
            GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == BaseCType(intArrayRefT):
        saved_variables.append(f'std::vector<int64_t> {name};')
        getter_definitions.append(
            GETTER_DEFINITION.substitute(op=info.op, name=name,
                                         body=GETTER_BODY_ARRAYREF_LONG))
    elif type == BaseCType(optionalIntArrayRefT):
        saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif type == OptionalCType(BaseCType(intArrayRefT)):
        saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
        saved_variables.append(f'c10::OptionalArray<double> {name};')
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
    elif type == BaseCType(longT):
        saved_variables.append(f'{type.cpp_type()} {name} = 0;')
        getter_definitions.append(
            GETTER_DEFINITION.substitute(op=info.op, name=name,
                                         body=GETTER_BODY_INT64_T))
    elif type == BaseCType(stringT):
        saved_variables.append(f'std::string {name};')
        getter_definitions.append(
            GETTER_DEFINITION.substitute(op=info.op, name=name,
                                         body=GETTER_BODY_STRING))
    elif type == OptionalCType(BaseCType(stringT)):
        saved_variables.append(f'c10::optional<std::string> {name};')
        getter_definitions.append(
            GETTER_DEFINITION_OPT.substitute(op=info.op, name=name,
                                             body=GETTER_BODY_STRING))
    else:
        # Fallback: store the value verbatim and use a stock getter if one is
        # registered for this CType in MISC_GETTER_DEFS.
        saved_variables.append(f'{type.cpp_type()} {name};')

        if type in MISC_GETTER_DEFS:
            getter_def, body = MISC_GETTER_DEFS[type]
            getter_definitions.append(
                getter_def.substitute(op=info.op, name=name, body=body))
        else:
            # Types we don't expose python bindings to yet:
            #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
            #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
            should_append_getsetdef = False

    if should_append_getsetdef:
        py_getsetdef_structs.append(
            PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
    if should_append_raw_getsetdef:
        py_getsetdef_structs.append(
            PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate an argument Type into the CType used by the C++ API."""
    # If it's a value type, do the value type translation
    value_result = valuetype_type(t, binds=binds)
    if value_result is not None:
        return value_result

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            ref_ctor = MutRefCType if mutable else ConstRefCType
            return ref_ctor(BaseCType('Tensor', binds))
        if t.name == BaseTy.Scalar:
            return ConstRefCType(BaseCType('Scalar', binds))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        elem_repr = str(t.elem)
        if elem_repr == 'Tensor':
            if mutable:
                # TODO: fix this discrepancy
                return MutRefCType(BaseCType('Tensor', binds))
            return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
        if elem_repr == 'Scalar':
            return ConstRefCType(OptionalCType(BaseCType('Scalar', binds)))
        return OptionalCType(argumenttype_type(t.elem, mutable=mutable, binds=binds))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, need to add back
        elem_repr = str(t.elem)
        flat_specials = {
            'int': "IntArrayRef",
            'Tensor': "TensorList",
            'Scalar': "ArrayRef<Scalar>",
            'Dimname': "DimnameList",
        }
        if elem_repr in flat_specials:
            return BaseCType(flat_specials[elem_repr], binds)
        if elem_repr == 'Tensor?':
            return ConstRefCType(
                BaseCType("c10::List<c10::optional<Tensor>>", binds))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")