def process_ir_type(
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
        (1) changing at::Tensors into lazy::Values
        (2) wrapping everything in a BaseCType
        (3) making cpp-reference types into cpp-value types (e.g. vector instead
            of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which
    Lazy IR represents tensors). There is special handling for Optional[Tensor],
    List[Tensor], etc. - hence 'tensor-like'.

    This is incomplete: the assertions below are expected to fire as the codegen
    is used with more operators, at which point more types will need to be added.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
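# Illustrative sketch (not part of the codegen): assuming the torchgen model
# types above are in scope and `props` is a hypothetical LazyIrProperties with
# TreatScalarsAsConstants unset, process_ir_type maps schema types like so:
#
#   process_ir_type(BaseType(BaseTy.Tensor), props)
#       -> BaseCType(getValueT())               # tensors become lazy::Values
#   process_ir_type(OptionalType(BaseType(BaseTy.Tensor)), props)
#       -> OptionalCType(BaseCType(getValueT()))
#   process_ir_type(ListType(BaseType(BaseTy.int), None), props)
#       -> VectorCType(BaseCType(longT))        # reference type -> value type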
def returntype_type(t: Type, *, mutable: bool) -> CType:
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")
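# Illustrative sketch (not part of the codegen): a mutable Tensor return maps
# to MutRefCType(BaseCType(tensorT)) (i.e. Tensor&), or to const Tensor& when
# use_const_ref_for_mutable_tensors is set; an immutable Tensor return is a
# by-value BaseCType(tensorT) copy (see Note [Tensor Copy Returns]); and a
# Tensor[] return becomes VectorCType(BaseCType(tensorT)).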
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "SymInt":
            return NamedCType(binds, BaseCType(symIntArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
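# Illustrative sketch (not part of the codegen): argumenttype_type maps schema
# argument types to C++ API types, e.g.
#
#   Tensor            -> const Tensor&                 (immutable borrow)
#   Tensor (mutable)  -> Tensor&                       (unless const-ref mode)
#   Tensor?           -> const c10::optional<Tensor>&
#   int[]             -> IntArrayRef, or std::vector<int64_t> when
#                        remove_non_owning_ref_types=True
#   Tensor[]          -> TensorList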
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
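# Illustrative sketch (not part of the codegen): with default properties,
#
#   isValueType(BaseCType(getValueT()))                 -> True
#   isValueType(OptionalCType(BaseCType(getValueT())))  -> True   (recurses)
#   isValueType(BaseCType(scalarT))                     -> True, but False when
#       properties.TreatScalarsAsConstants is set
#   isValueType(VectorCType(BaseCType(SymIntT)))        -> False  (see TODO above)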
def solve(goal: NamedCType, *, direct: bool) -> str:
    def direct_solve(goal: NamedCType) -> str:
        return solve(goal, direct=True)

    if goal in ctx:
        # Trivial
        return ctx[goal]

    # const & is satisfied with mutable &
    if isinstance(goal.type, ConstRefCType):
        try:
            # WARNING: not strictly decreasing; be careful not
            # to add a direct conversion that satisfies
            # mutable& with const&
            return solve(
                NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
            )
        except UnsatError:
            pass

    # mutable & is satisfied with value
    if isinstance(goal.type, MutRefCType):
        try:
            return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
        except UnsatError:
            pass

    if direct:
        unsat(goal)

    # For now, all of these rules are mutually exclusive.
    if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
        memory_format = direct_solve(
            NamedCType(
                SpecialArgName.possibly_redundant_memory_format,
                OptionalCType(BaseCType(memoryFormatT)),
            )
        )
        # No need to join "memory_format" and "options" if the target API takes
        # "options" directly. Otherwise it will cause the redundant memory_format error.
        if options_ctype in goal_ctypes:
            return memory_format
        try:
            options = direct_solve(options_ctype)
            return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
        except UnsatError:
            return memory_format
    elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
        dtype = direct_solve(NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))))
        pin_memory = direct_solve(NamedCType("pin_memory", OptionalCType(BaseCType(boolT))))
        device = direct_solve(NamedCType("device", OptionalCType(BaseCType(deviceT))))
        layout = direct_solve(NamedCType("layout", OptionalCType(BaseCType(layoutT))))
        return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
    elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
        try:
            options = direct_solve(options_ctype)
            return f"optTypeMetaToScalarType({options}.dtype_opt())"
        except UnsatError:
            out_tensor = direct_solve(out_tensor_ctype)
            return f"{out_tensor}.scalar_type()"
    elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
        try:
            options = direct_solve(options_ctype)
            return f"{options}.layout_opt()"
        except UnsatError:
            out_tensor = direct_solve(out_tensor_ctype)
            return f"{out_tensor}.layout()"
    elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
        try:
            options = direct_solve(options_ctype)
            return f"{options}.device_opt()"
        except UnsatError:
            out_tensor = direct_solve(out_tensor_ctype)
            return f"{out_tensor}.device()"
    elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
        try:
            options = direct_solve(options_ctype)
            return f"{options}.pinned_memory_opt()"
        except UnsatError:
            # If we're calling a factory op from its out= variant,
            # we don't actually care about the value of pin_memory.
            out_tensor = direct_solve(out_tensor_ctype)
            return "c10::nullopt"

    # We can always do translations from value types to reference types,
    # like vector<int> -> IntArrayRef
    elif goal.type == BaseCType(intArrayRefT):
        try:
            return direct_solve(NamedCType(goal.name, longVec_ctype))
        except UnsatError:
            # We can also go SymIntArrayRef -> IntArrayRef
            symIntArrayRef_type = direct_solve(
                NamedCType(goal.name, BaseCType(symIntArrayRefT))
            )
            return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
    elif goal.type == BaseCType(symIntArrayRefT):
        return direct_solve(NamedCType(goal.name, longSymVec_ctype))
    elif goal.type == BaseCType(longT):
        symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
        return f"{symInt_type}.expectInt()"
    elif goal.type == BaseCType(optionalIntArrayRefT):
        return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
    elif goal.type == BaseCType(optionalScalarRefT):
        return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
    elif goal.type == BaseCType(optionalTensorRefT):
        return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

    # Note [translation from C++ reference to value types]
    # The below cases are all for when we have an argument with a reference type,
    # and a corresponding goal with a value type.
    # These are needed when we populate the inputs to a lambda capture and we need
    # to guarantee the lifetime of each captured argument.
    # We guard it with an explicit kwarg because converting to a value type is
    # expensive (e.g. O(n) to convert from IntArrayRef to vector<int>),
    # so the caller of translate() should be explicit that they need it.
    if allow_expensive_conversions:
        if goal.type == VectorCType(BaseCType(longT)):
            intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
            argname = direct_solve(intArrayRef_ctype)
            return f"{argname}.vec()"
        elif goal.type == VectorCType(BaseCType(SymIntT)):
            symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
            argname = direct_solve(symIntArrayRef_ctype)
            return f"{argname}.vec()"
        elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
            optionalIntArrayRef_ctype = NamedCType(
                goal.name, BaseCType(optionalIntArrayRefT)
            )
            argname = direct_solve(optionalIntArrayRef_ctype)
            return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
        elif goal.type == OptionalCType(BaseCType(scalarT)):
            optionalScalarRef_ctype = NamedCType(
                goal.name, BaseCType(optionalScalarRefT)
            )
            argname = direct_solve(optionalScalarRef_ctype)
            return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
        elif goal.type == OptionalCType(BaseCType(tensorT)):
            optionalTensorRef_ctype = NamedCType(
                goal.name, BaseCType(optionalTensorRefT)
            )
            argname = direct_solve(optionalTensorRef_ctype)
            return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
        # Technically, we also need to handle cases of C++ containers holding
        # reference types. But there currently aren't any ops that require
        # lambda capture codegen with arguments like std::vector<IntArrayRef>.
        # If that changes, we'll have to add the translation here.

    # We allow const casting on tensors, since const-correctness is a bit broken
    # for at::Tensor. We could probably generalize this to non-tensor types too.
    if goal.type == MutRefCType(BaseCType(tensorT)):
        const_ref_tensor_ctype = NamedCType(
            goal.name, ConstRefCType(BaseCType(tensorT))
        )
        argname = direct_solve(const_ref_tensor_ctype)
        return f"const_cast<Tensor&>({argname})"

    unsat(goal)
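# Illustrative sketch (not part of the solver): with "dtype", "layout",
# "device", and "pin_memory" bindings in scope, the goal
# NamedCType("options", BaseCType(tensorOptionsT)) is Scattered into
#
#   TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)
#
# while with only "options" in scope, the goal "dtype" Gathers
# optTypeMetaToScalarType(options.dtype_opt()); if neither is available but an
# out tensor is, the solver falls back to out.scalar_type().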
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
#   in the context, instead, "options" is, and you need to extract
#   it from there. (Gather)
#
# - Need the "context" binding? Well, maybe "context" isn't available
#   in the context, and you need to construct it from "dtype", "device",
#   etc. (Scatter)
#
# - Need the "memory_format" binding? Well, actually, it's available
#   from both "memory_format" and "options", so you had better make sure
#   they are consistent. (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals. You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
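# Concretely (an illustrative sketch, not from this file): if f's bindings
# include "const TensorOptions& options" and the goal list asks for
# "c10::optional<ScalarType> dtype", the solver above lets codegen emit
#
#   void f(const TensorOptions& options) {
#     g(optTypeMetaToScalarType(options.dtype_opt()));
#   }
#
# with the argument expression for g synthesized rather than hand-written.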
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (
            r"{}.sizes\(\)",
            {
                "suffix": "_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            },
        ),
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sizes() with self_sizes_opt
        (
            r"{}->sizes\(\)",
            {
                "suffix": "_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(intArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.size(2) with self_size_2
        (
            r"{}.size\((\w+)\)",
            {
                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.strides() with self_strides
        (
            r"{}.strides\(\)",
            {
                "suffix": "_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)
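# Illustrative sketch (not part of load_derivatives): for the hypothetical
# derivative formula "grad * self.sizes()[1]" with self among nctypes, the
# ".sizes()" replacement rewrites the formula to "grad * self_sizes[1]" and
# records
#
#   SavedAttribute(
#       nctype=NamedCType("self_sizes", BaseCType(intArrayRefT)),
#       expr="self.sizes()",
#   )
#
# so only the sizes are captured when the autograd node is created, instead of
# saving the whole tensor.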
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
    if t == BaseCType(tensorListT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
    # There are technically other non-owning types out there (like IntArrayRef),
    # but functionalization only actually cares about the ones involving tensors.
    return t, lambda x: x
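# Illustrative sketch (not part of functionalization): TensorList is a
# non-owning ArrayRef of Tensors, so to capture one safely get_owning_type
# pairs the owning vector type with the expression that performs the copy:
#
#   owning_type, conv = get_owning_type(BaseCType(tensorListT))
#   # owning_type == VectorCType(BaseCType(tensorT))
#   # conv("self") == "self.vec()"
#
# Any other CType is returned unchanged with an identity conversion.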