def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
) -> List[Binding]:
    """Return the list of bindings produced by a single argument.

    Args:
        a: The argument to translate.  A ``SelfArgument`` is unwrapped and
            handled exactly like the ``Argument`` it carries; a
            ``TensorOptionsArguments`` expands into four bindings named
            ``dtype``, ``layout``, ``device`` and ``pin_memory``.
        is_out: When true, no default expressions are generated.

    Returns:
        One binding for a plain/self argument, four for tensor options.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    emit_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: Optional[str] = (
            cpp.default_expr(a.default, a.type)
            if emit_defaults and a.default is not None
            else None
        )
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=arg_default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)

    if isinstance(a, TensorOptionsArguments):
        options_default: Optional[str] = "{}" if emit_defaults else None
        # (binding name, C++ base type) for each constituent, in emission order.
        pieces = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(piece, OptionalCType(BaseCType(base_t))),
                name=piece,
                default=options_default,
                argument=a,
            )
            for piece, base_t in pieces
        ]

    assert_never(a)
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Build the binding for ``a`` using the type computed by ``ufunc_type``.

    The binding keeps the argument's own name, never carries a default, and
    refers back to ``a`` as its originating argument.
    """
    named_type = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(nctype=named_type, name=a.name, default=None, argument=a)
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Build the binding for ``a`` using ``ufunctor_apply_type``.

    The binding keeps the argument's own name, never carries a default, and
    refers back to ``a`` as its originating argument.
    """
    named_type = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(nctype=named_type, name=a.name, default=None, argument=a)
def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding:
    """Build the binding for ``a`` with the type given by ``argument_type``.

    Args:
        a: The argument to bind; its name is used for the binding.
        remove_non_owning_ref_types: Forwarded unchanged to
            ``argument_type``.

    Returns:
        A binding named after ``a``; no ``default`` is passed to ``Binding``.
    """
    named_type = argument_type(
        a,
        binds=a.name,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    return Binding(nctype=named_type, name=a.name, argument=a)
def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
    """Generate unpacking statements and the resulting bindings for ``f``.

    Every schema-order argument of ``f`` is turned into bindings via
    ``cpp.argument``.  Arguments that are not tensor-like, or that are
    nullable, are passed through unchanged.  Each remaining argument gets one
    ``UNPACK_TENSOR`` statement and a binding renamed with ``unpacked_name``.

    Args:
        f: The native function whose arguments are unpacked.

    Returns:
        A pair ``(body, unpacked_bindings)``: the generated statements and
        one binding per input binding, in the original order.

    Raises:
        RuntimeError: If any binding originates from a
            ``TensorOptionsArguments``.
    """
    body: List[str] = []
    unpacked_bindings: List[Binding] = []

    bindings = [
        r
        for a in f.func.schema_order_arguments()
        for r in cpp.argument(
            a,
            method=False,
            symint=True,
            cpp_no_default_args=set(),
            faithful=False,
            has_tensor_options=False,
        )
    ]

    for i, binding in enumerate(bindings):
        assert not isinstance(binding.argument, SelfArgument)
        if isinstance(binding.argument, TensorOptionsArguments):
            raise RuntimeError("VariableKernel shouldn't take TensorOptions")

        arg_type = binding.argument.type
        is_nullable = arg_type.is_nullable()
        if not arg_type.is_tensor_like() or is_nullable:
            # Non-tensor-like and nullable arguments are forwarded as they are.
            unpacked_bindings.append(binding)
            continue

        # Only non-nullable tensor-like arguments reach this point, so the
        # suffix is always empty and `ref` depends solely on whether the
        # argument is a tensor list.  `arg_pos` is the index among all
        # bindings, including the ones passed through above.
        ref = "" if is_tensor_list_type(arg_type) else "&"
        body.append(
            UNPACK_TENSOR.substitute(
                arg_name=binding.name,
                arg_pos=i,
                suffix="",
                ref=ref,
            )
        )
        unpacked_bindings.append(
            Binding(
                name=unpacked_name(binding.name),
                nctype=binding.nctype,
                argument=binding.argument,
                default=binding.default,
            )
        )

    return body, unpacked_bindings
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments]
) -> List[Binding]:
    """Return the bindings produced by a single argument.

    A plain ``Argument`` yields exactly one binding without a default.  A
    ``SelfArgument`` is unwrapped and handled like the ``Argument`` it
    carries.

    Raises:
        AssertionError: If ``a`` is a ``TensorOptionsArguments``.
    """
    if isinstance(a, Argument):
        named_type = argument_type(a, binds=a.name)
        return [Binding(nctype=named_type, name=a.name, default=None, argument=a)]

    if isinstance(a, SelfArgument):
        return argument(a.argument)

    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")

    assert_never(a)
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> List[Binding]:
    """Return the bindings produced by a single argument.

    Args:
        a: The argument to translate.
        cpp_no_default_args: Names of arguments for which no default
            expression is generated.
        method: When true, a ``SelfArgument`` produces no bindings.
        faithful: When true, a ``TensorOptionsArguments`` expands into the
            bindings of its ``dtype``, ``layout``, ``device`` and
            ``pin_memory`` members; otherwise it becomes a single
            ``options`` binding.
        has_tensor_options: When true, an argument named ``memory_format``
            is bound as ``SpecialArgName.possibly_redundant_memory_format``.

    Returns:
        Zero or more bindings, in declaration order.
    """

    def recurse(
        inner: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        # Re-enter with exactly the same keyword configuration.
        return argument(
            inner,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        arg_default: Optional[str] = (
            default_expr(a.default, a.type)
            if a.name not in cpp_no_default_args and a.default is not None
            else None
        )
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=arg_default,
                argument=a,
            )
        ]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            expanded: List[Binding] = []
            for member in (a.dtype, a.layout, a.device, a.pin_memory):
                expanded = expanded + recurse(member)
            return expanded

        # Enforced by NativeFunction.__post_init__
        assert "options" not in cpp_no_default_args
        options_default: Optional[str]
        if all(x.default == "None" for x in a.all()):
            options_default = "{}"
        elif a.dtype.default == "long":
            options_default = "at::kLong"  # TODO: this is wrong
        else:
            options_default = None
        return [
            Binding(
                nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                name="options",
                default=options_default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        return recurse(a.argument)

    assert_never(a)
def compute_ufunc_cpu_dtype_body( g: NativeFunctionsGroup, dtype: ScalarType, inner_loops: Dict[UfuncKey, UfuncSignature], parent_ctx: Sequence[Binding], ) -> str: assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} scalar_loop = inner_loops[UfuncKey.CPUScalar] vec_loop = None if UfuncKey.CPUVector in inner_loops: vec_loop = inner_loops[UfuncKey.CPUVector] # NB: We DON'T use translate here, because translate is # incapable of CSE'ing the scalar accesses in case it is also # used by Vectorized; also, the unpacking here is very simple # and only affects Scalar; everything else is implicitly captured # by the lambda # Setup scalar in scope body = [] ctx = [] for b in parent_ctx: if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): continue body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();") ctx.append( Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) if vec_loop is not None: for b in parent_ctx: if isinstance( b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar): continue body.append( f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});" ) ctx.append( Expr( f"_v_{b.name}", NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))), )) # Setup lambda signature # NB: simplified version of ufunctor_arguments scalar_bindings = [] vec_bindings = [] for a in g.functional.func.arguments.flat_non_out: if not a.type.is_tensor_like(): continue assert a.type == BaseType(BaseTy.Tensor) scalar_bindings.append( Binding( name=a.name, nctype=NamedCType(a.name, BaseCType(scalar_t)), argument=a, )) if vec_loop is not None: vec_bindings.append( Binding( name=a.name, nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), argument=a, )) def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: r: List[Union[Expr, Binding]] = [] r.extend(ctx) r.extend(b) return r body_str = "\n".join(body) 
if vec_loop is not None: return f""" {body_str} cpu_kernel_vec(iter, [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} ); """ else: return f"""
# - While the forward lambda just directly calls into the at::_ops API # (following the dispatcher convention), the logic here for the reverse lambda # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). # The lambdas generated for each view op in the functionalization pass are of the form # [capture_arguments](outer_arguments) -> returns_type { # return name(inner_arguments); # } # Define some specific lambda input arguments. base_binding = Binding( name="base", nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))), argument=Argument(name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None, ) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), argument=Argument(name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None, ) mutated_view_idx_binding = Binding(