def returntype_type(t: Type, *, mutable: bool) -> CType:
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")
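
# A quick illustration of the mapping above (illustrative sketch only; assumes
# torchgen.model is importable from a PyTorch checkout):
def _demo_returntype_type() -> None:
    from torchgen.model import BaseTy, BaseType, ListType

    # Non-mutable Tensor returns are copies (Note [Tensor Copy Returns]).
    assert returntype_type(BaseType(BaseTy.Tensor), mutable=False) == BaseCType(tensorT)
    # List returns always come back as an owning std::vector, never a reference.
    assert returntype_type(
        ListType(BaseType(BaseTy.Tensor), None), mutable=False
    ) == VectorCType(BaseCType(tensorT))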
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
) -> List[Binding]:
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = "{}"
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
                name="dtype",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
                name="layout",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
                name="device",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
                name="pin_memory",
                default=default,
                argument=a,
            ),
        ]
    else:
        assert_never(a)
def valuetype_type(
    t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> Optional[NamedCType]:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError(
                    "string ref->value conversion: not implemented yet"
                )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
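
# Sketch of how valuetype_type classifies schema types (illustrative only;
# assumes torchgen.model is importable from a PyTorch checkout):
def _demo_valuetype_type() -> None:
    from torchgen.model import BaseTy, BaseType

    # "int" is a plain value type and maps straight to int64_t.
    assert valuetype_type(BaseType(BaseTy.int), binds="x") == NamedCType(
        "x", BaseCType(longT)
    )
    # Tensor (like Scalar) is not a value type: callers get None back and fall
    # through to their own reference-type rules.
    assert valuetype_type(BaseType(BaseTy.Tensor), binds="x") is None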
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
    """Generate an additional lambda function to recover views in backward when
    as_strided is not supported.  See Note [View + Inplace update for base tensor]
    and [View + Inplace update for view tensor] for more details.
    """
    input_base = "input_base"
    replay_view_func = ""
    updated_unpacked_args: List[str] = []
    known_view_arg_simple_types: List[CType] = [
        BaseCType(longT),
        OptionalCType(BaseCType(longT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
    ]
    for unpacked_binding in unpacked_bindings:
        arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
        if arg == "self_":
            updated_unpacked_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )

        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
            symIntArrayRefT
        ):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(longT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base, replay_view_call=replay_view_call
    )

    is_view_with_metadata_change = (
        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
    )

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
    )
def returns_type(rs: Sequence[Return]) -> CType:
    if len(rs) == 0:
        return BaseCType(voidT)
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        return TupleCType([return_type(r) for r in rs])
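
# Sketch of the three cases (illustrative only; Return comes from
# torchgen.model, and return_type is the sibling helper in this module):
def _demo_returns_type() -> None:
    from torchgen.model import BaseTy, BaseType, Return

    r = Return(name=None, type=BaseType(BaseTy.Tensor), annotation=None)
    assert returns_type([]) == BaseCType(voidT)      # void
    assert returns_type([r]) == BaseCType(tensorT)   # at::Tensor
    # std::tuple<at::Tensor, at::Tensor>
    assert returns_type([r, r]) == TupleCType([BaseCType(tensorT), BaseCType(tensorT)])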
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    if str(t) == "Tensor?":
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        else:
            return NamedCType(binds, ConstRefCType(tensor_type))
    elif str(t) == "Tensor?[]":
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )
    elif str(t) == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif str(t) == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    if t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(scalar_t))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
def returns_type(func: FunctionSchema) -> CType:
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) >= 1
    for ret in func.returns:
        assert ret.type.is_tensor_like()
    # However, the return type of the lambda is always an individual tensor.
    # For multi-tensor outputs, each tensor needs to be tracked individually.
    return BaseCType(tensorT)
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError()
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif t == BaseType(BaseTy.Tensor):
        return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
    value_args = schema.filtered_args(values=True, scalars=False)
    scalar_args = schema.filtered_args(values=False, scalars=True)
    value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
    optional_device = OptionalCType(BaseCType(deviceT))
    optional_devices = [a.name for a in scalar_args if a.lazy_type == optional_device]
    assert (
        len(value_types_names) > 0 or len(optional_devices) > 0
    ), "Expected at least one Value or Device type"
    get_device_str = (
        f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
    )
    return f"""auto common_device = {get_device_str};
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # TODO: delete these special cases; see torchgen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
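
# Sketch: the structured variant above hands optional tensors to kernels as
# the borrowing at::OptionalTensorRef wrapper rather than c10::optional<Tensor>
# (illustrative only; assumes torchgen.model is importable):
def _demo_structured_argumenttype_type() -> None:
    from torchgen.model import BaseTy, BaseType, OptionalType

    t = OptionalType(BaseType(BaseTy.Tensor))
    assert argumenttype_type(t, mutable=False, binds="weight") == NamedCType(
        "weight", BaseCType(optionalTensorRefT)
    )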
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
    return body
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type.  This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
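
# Sketch: Value-ness is decided at the leaf type and survives Optional/List/
# Vector wrapping (illustrative only; uses the CType helpers above):
def _demo_isValueType() -> None:
    assert isValueType(BaseCType(getValueT()))
    assert isValueType(OptionalCType(BaseCType(getValueT())))
    # A plain int64_t is not Value-like.
    assert not isValueType(BaseCType(longT))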
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
            BaseTy.Scalar
        ):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
                BaseTy.Scalar
            ):
                continue
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        r: List[Union[Expr, Binding]] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> List[Binding]:
    def sub_argument(
        a: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return (
                sub_argument(a.dtype)
                + sub_argument(a.layout)
                + sub_argument(a.device)
                + sub_argument(a.pin_memory)
            )
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert "options" not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = "{}"
            elif a.dtype.default == "long":
                default = "at::kLong"  # TODO: this is wrong
            return [
                Binding(
                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                    name="options",
                    default=default,
                    argument=a,
                )
            ]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)
def process_ir_type(
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which
    Lazy IR represents tensors.)  There is special handling for Optional[Tensor]
    or List[Tensor], etc. - hence 'tensor-like'.

    This is incomplete: there are assertions in places where it's expected that
    more types will need to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
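
# Sketch of conversions (1) and (3) from the docstring (illustrative only;
# assumes torchgen.model is importable, and that LazyIrProperties is
# default-constructible, which is an assumption about that class):
def _demo_process_ir_type() -> None:
    from torchgen.model import BaseTy, BaseType, ListType

    props = LazyIrProperties()  # assumed default-constructible
    # (1) at::Tensor becomes a lazy::Value.
    assert process_ir_type(BaseType(BaseTy.Tensor), props) == BaseCType(getValueT())
    # (3) int[] becomes an owning std::vector<int64_t>, not an IntArrayRef.
    assert process_ir_type(
        ListType(BaseType(BaseTy.int), None), props
    ) == VectorCType(BaseCType(longT))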
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site, and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
base_binding = Binding(
    name="base",
    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "SymInt":
            return NamedCType(binds, BaseCType(symIntArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
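
# Sketch: by default int[] lowers to the non-owning at::IntArrayRef, and the
# remove_non_owning_ref_types knob swaps in an owning std::vector<int64_t>
# (illustrative only; assumes torchgen.model is importable):
def _demo_cpp_argumenttype_type() -> None:
    from torchgen.model import BaseTy, BaseType, ListType

    t = ListType(BaseType(BaseTy.int), None)
    assert argumenttype_type(t, mutable=False, binds="sizes") == NamedCType(
        "sizes", BaseCType(intArrayRefT)
    )
    assert argumenttype_type(
        t, mutable=False, binds="sizes", remove_non_owning_ref_types=True
    ) == NamedCType("sizes", VectorCType(BaseCType(longT)))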
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (
            r"{}.sizes\(\)",
            {
                "suffix": "_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            },
        ),
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sizes() with self_sizes_opt
        (
            r"{}->sizes\(\)",
            {
                "suffix": "_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(intArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.size(2) with self_size_2
        (
            r"{}.size\((\w+)\)",
            {
                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.strides() with self_strides
        (
            r"{}.strides\(\)",
            {
                "suffix": "_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)
def save_var(var: SavedAttribute, is_output: bool) -> None:
    name = var.nctype.name
    type = var.nctype.type
    should_append_getsetdef = True
    should_append_raw_getsetdef = False

    if (
        type == BaseCType(tensorT)
        or type == OptionalCType(BaseCType(tensorT))
        or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
        or (type == BaseCType(scalarT) and is_output)
    ):
        saved_variables.append(f"SavedVariable {name}_;")
        release_variables.append(f"{name}_.reset_data();")
        ptr = "shared_from_this()" if is_output else ""
        unpack.append(f"auto {name} = {name}_.unpack({ptr});")
        getter_definitions.append(
            GETTER_DEFINITION_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
            )
        )
        getter_definitions.append(
            GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
            )
        )
        should_append_raw_getsetdef = True
    elif type == BaseCType(tensorListT):
        saved_variables.append(f"std::vector<SavedVariable> {name}_;")
        saved_variables.append(f"bool {name}_released_ = false;")
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the
        # SavedVariable makes them go away as well.
        release_variables.append(f"{name}_.clear();")
        release_variables.append(f"{name}_released_ = true;")
        unpack.append(f"auto {name} = unpack_list({name}_);")
        asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
        getter_definitions.append(
            GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
            )
        )
        getter_definitions.append(
            GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
            )
        )
        should_append_raw_getsetdef = True
    elif type == ListCType(OptionalCType(BaseCType(tensorT))):
        saved_variables.append(f"std::vector<SavedVariable> {name}_;")
        saved_variables.append(f"bool {name}_released_ = false;")
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the
        # SavedVariable makes them go away as well.
        release_variables.append(f"{name}_.clear();")
        release_variables.append(f"{name}_released_ = true;")
        unpack.append(f"auto {name} = unpack_opt_list({name}_);")
        asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
        getter_definitions.append(
            GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
            )
        )
        getter_definitions.append(
            GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
            )
        )
        should_append_raw_getsetdef = True
    elif type == BaseCType(intArrayRefT):
        saved_variables.append(f"std::vector<int64_t> {name};")
        getter_definitions.append(
            GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
            )
        )
    elif type == BaseCType(optionalIntArrayRefT):
        saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
            )
        )
    elif type == OptionalCType(BaseCType(intArrayRefT)):
        saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
            )
        )
    elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
        saved_variables.append(f"c10::OptionalArray<double> {name};")
        getter_definitions.append(
            GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
            )
        )
    elif type == BaseCType(longT):
        saved_variables.append(f"{type.cpp_type()} {name} = 0;")
        getter_definitions.append(
            GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_INT64_T
            )
        )
    elif type == BaseCType(stringT):
        saved_variables.append(f"std::string {name};")
        getter_definitions.append(
            GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_STRING
            )
        )
    elif type == OptionalCType(BaseCType(stringT)):
        saved_variables.append(f"c10::optional<std::string> {name};")
        getter_definitions.append(
            GETTER_DEFINITION_OPT.substitute(
                op=info.op, name=name, body=GETTER_BODY_STRING
            )
        )
    else:
        saved_variables.append(f"{type.cpp_type()} {name};")

        if type in MISC_GETTER_DEFS:
            getter_def, body = MISC_GETTER_DEFS[type]
            getter_definitions.append(
                getter_def.substitute(op=info.op, name=name, body=body)
            )
        else:
            # Types we don't expose python bindings to yet:
            #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
            #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
            should_append_getsetdef = False

    if should_append_getsetdef:
        py_getsetdef_structs.append(
            PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
        )
    if should_append_raw_getsetdef:
        py_getsetdef_structs.append(
            PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
        )
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""

MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if (
        self.target is Target.REGISTRATION
        and not self.selector.is_native_function_selected(f)
    ):
        return None

    # TODO: Now, there is something interesting going on here.  In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation.  But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor.  How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other.  We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG!  So it's not implemented yet.
    if (
        self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd
        and f.func.kind() is SchemaKind.out
    ):
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:

        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif (
            self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd
        ):
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if self.backend_index.device_guard:
            device_check_args = itertools.chain(
                f.func.arguments.out, f.func.arguments.flat_positional
            )
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()
                )
            )

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ", ".join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ", ".join(
            e.expr
            for e in translate(context, structured.meta_arguments(self.g), method=False)
        )

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            precomputed_values = [
                *self.g.out.precomputed.replace.values(),
                self.g.out.precomputed.add,
            ]
            for precomputed_elems in precomputed_values:
                for arg in precomputed_elems:
                    context.append(
                        Expr(
                            expr=f"precompute.{arg.name}",
                            type=structured.argument_type(arg, binds=arg.name),
                        )
                    )

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        maybe_star = "*" if k is SchemaKind.functional else ""
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(
                Expr(
                    expr=f"{maybe_star}op.outputs_[{i}]",
                    # TODO: Stop hardcoding that the output type is a Tensor.  Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=NamedCType(
                        out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                    ),
                )
            )

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding
            )
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ", ".join(
                e.expr for e in translate(context, out_sig.arguments(), method=False)
            )
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set.  I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.impl_arguments(self.g), method=False
                )
            )
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
            else:
                moved = ", ".join(
                    f"std::move(op.outputs_[{i}]).take()"
                    for i in range(len(f.func.returns))
                )
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ", ".join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it
        # unless someone really forces us to)
        ufunc_name = None
        supported_dtypes = set()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
# other scope); others are more nontrivial and may require packing/unpacking.
#
# Some examples of non-trivial action:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context, instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "context" binding?  Well, maybe "context" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings to satisfy the
# target types.
def returns_type(self) -> CType:
    # TODO: don't hardcode; return type will be inferred based on tags on
    # the native function
    return BaseCType(scalar_t)
def translate(
    bindings: Sequence[Union[Expr, Binding]],
    goals: Sequence[Union[NamedCType, Binding]],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> List[Expr]:
    binding_exprs: List[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(
                Expr(
                    expr=b.name,
                    type=b.nctype,
                )
            )
        else:
            binding_exprs.append(b)

    goal_ctypes: List[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: Dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type
        # NB: t is a NamedCType, so the structural check has to look at t.type
        if (
            isinstance(t.type, ConstRefCType)
            and isinstance(t.type.elem, OptionalCType)
            and isinstance(t.type.elem.elem, BaseCType)
            and str(t.type.elem.elem.type) == "at::Tensor"
        ):
            ctx[
                NamedCType(t.name, ConstRefCType(BaseCType(tensorT)))
            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: ITensorListRef]
        if t.type == BaseCType(tensorListT):
            ctx[
                NamedCType(t.name, BaseCType(iTensorListRefT))
            ] = f"at::ITensorListRef({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A shitty backtracking search implementation.  It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if direct=True, we won't try to do
    # any fancy synthesis, just trivial conversions (e.g., "T a" is OK for
    # "const T& a").  So all of the existing rules in this function simply
    # try to solve immediately, and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"
        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"
        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"
        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # we don't actually care about the value of pin_memory.
                out_tensor = direct_solve(out_tensor_ctype)
                return "c10::nullopt"

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"c10::asIntArrayRefSlow({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.expectInt()"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is
        # expensive (O(n) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            if goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT)
                )
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT)
                )
                argname = direct_solve(optionalScalarRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT)
                )
                argname = direct_solve(optionalTensorRef_ctype)
                return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # with arguments like std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken
        # for at::Tensor.  We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT))
            )
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
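
# Sketch of translate in action: the "gather" case from the notes earlier in
# this module. Given only a TensorOptions binding in scope, translate
# synthesizes the expression that extracts the optional dtype from it
# (illustrative only; uses the same torchgen.api.types names as above):
def _demo_translate_gather() -> None:
    opts = Expr(
        expr="options",
        type=NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))),
    )
    goal = NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
    [e] = translate([opts], [goal])
    assert e.expr == "optTypeMetaToScalarType(options.dtype_opt())"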
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
    if t == BaseCType(tensorListT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
    # There are technically other non-owning types out there (like IntArrayRef),
    # but functionalization only actually cares about the ones involving tensors.
    return t, lambda x: x
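
# Sketch: at::TensorList (a non-owning ArrayRef) is the one case this helper
# converts, yielding an owning std::vector<at::Tensor> plus a ".vec()"
# conversion; anything else passes through unchanged (illustrative only):
def _demo_get_owning_type() -> None:
    owning, conv = get_owning_type(BaseCType(tensorListT))
    assert owning == VectorCType(BaseCType(tensorT))
    assert conv("tensors") == "tensors.vec()"

    same, ident = get_owning_type(BaseCType(longT))
    assert same == BaseCType(longT)
    assert ident("x") == "x"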