def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Build an out= variant of a mutable (``SchemaKind.mutable``) schema.

    The out= arguments and the matching returns are produced by
    ``generate_out_args_from_schema``. They are attached to a schema whose name
    has the in-place marker removed (``remove_inplace()``) and whose overload
    name comes from ``get_expected_out_variant_overload_name``.

    Intent, as described by the original authors:
      - non-aliased tensor-like returns become mutable, aliased out= arguments
        (a plain tensor is also returned for method chaining, otherwise nothing
        is returned for it);
      - (1) an out= variant can therefore only be generated from a mutable
        schema that has at least one tensor-like non-aliasing return;
      - (2) the generated variant still has mutable positional arguments; a
        further "functional_out" variant that also functionalizes them could
        be added if necessary.

    NOTE(review): a second ``mutable_to_out_signature`` is defined later in
    this file; at import time the later definition replaces this one.
    """
    assert func.kind() == SchemaKind.mutable

    returns_list, out_args = generate_out_args_from_schema(func)

    base_name = func.name.remove_inplace()
    out_overload = get_expected_out_variant_overload_name(func.name.overload_name)
    out_arguments = func.arguments.with_out_args(out_args)

    return FunctionSchema(
        name=base_name.with_overload(out_overload),
        arguments=out_arguments,
        returns=tuple(returns_list),
    )
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Return the operator's base name with symint/out suffixes applied.

    Suffixes are appended in this order:
      - ``_symint`` when ``func.is_symint_fn()`` is true;
      - for out= functions (``func.is_out_fn()``), ``_outf`` if
        ``faithful_name_for_out_overloads`` is set, otherwise ``_out``.

    NOTE(review): another function called ``name`` is defined later in this
    file; at import time the later definition replaces this one.
    """
    base = str(func.name.name)
    suffixes = []
    if func.is_symint_fn():
        suffixes.append("_symint")
    if func.is_out_fn():
        suffixes.append("_outf" if faithful_name_for_out_overloads else "_out")
    return base + "".join(suffixes)
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    """Emit one C++-style ``auto <arg><index> = <expr>`` declaration per argument.

    Arguments are visited in ``schema.schema_order_arguments()`` order. Each
    value expression comes from ``test_value_expression`` and may then be
    overridden by ``config.override_test_values``. The declarations are joined
    with ``";\\n "`` and the whole string ends with ``;``.

    Only non-out schemas are accepted (asserted).
    """
    assert not schema.is_out_fn()
    op_name = schema.name.name.base

    values_by_arg = {
        arg.name: test_value_expression(arg.type, index, op_name)
        for arg in schema.schema_order_arguments()
    }
    # The return value of override_test_values is ignored, so any override
    # presumably happens by mutating values_by_arg in place.
    config.override_test_values(values_by_arg, op_name, index)

    declarations = [
        f"auto {arg_name}{index} = {expr}" for arg_name, expr in values_by_arg.items()
    ]
    return ";\n ".join(declarations) + ";"
def name(func: FunctionSchema) -> str:
    """Return the operator name with ``_out`` and ``_<overload>`` suffixes.

    ``_out`` is appended for out= functions. ``_<overload_name>`` is appended
    whenever the schema has a non-empty overload name.
    """
    # TODO: delete this!
    out_suffix = "_out" if func.is_out_fn() else ""
    overload = func.name.overload_name
    overload_suffix = f"_{overload}" if overload else ""
    return f"{str(func.name.name)}{out_suffix}{overload_suffix}"
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    """Flatten ``func``'s arguments into a list of bindings.

    Non-out arguments come first, followed by out arguments. Each one is
    expanded via ``argument(..., symint=symint, is_out=func.is_out_fn())``, and
    all resulting bindings are concatenated in order.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    bindings = []
    for a in ordered:
        bindings.extend(argument(a, symint=symint, is_out=func.is_out_fn()))
    return bindings
def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Build an out= variant of a functional (``SchemaKind.functional``) schema.

    The out arguments and the returns come from
    ``generate_out_args_from_schema``. According to the original authors'
    description, these are:
      - out argument(s) typed like the returns but with a mutable annotation;
      - returns that alias those out= arguments.

    The overload name is taken from ``get_expected_out_variant_overload_name``.
    The argument list is ``func.arguments.signature()`` extended with the new
    out args.
    """
    assert func.kind() == SchemaKind.functional

    out_returns, out_args = generate_out_args_from_schema(func)

    base_name = func.name
    out_overload = get_expected_out_variant_overload_name(func.name.overload_name)
    out_arguments = func.arguments.signature().with_out_args(out_args)

    return FunctionSchema(
        name=base_name.with_overload(out_overload),
        arguments=out_arguments,
        returns=tuple(out_returns),
    )
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
    """Return C++ expressions for every return of ``func`` without an alias name.

    ``out_var`` is the variable holding the inner call's result.
      - If ``func`` has more than one return, the result is treated as a tuple
        and element ``i`` is read with ``std::get<i>(out_var)``.
      - Otherwise ``out_var`` itself is used.

    A return is selected when its entry in ``func.aliased_return_names()`` is
    ``None``.
    """
    alias_names = func.aliased_return_names()
    returns_tuple = len(func.returns) > 1
    return [
        f"std::get<{idx}>({out_var})" if returns_tuple else out_var
        for idx, alias in enumerate(alias_names)
        if alias is None
    ]
def generate_arg_extraction(schema: FunctionSchema) -> str:
    """Emit C++ statements that bind each schema argument from ``p_node->Input(i)``.

    For the i-th argument in ``schema.schema_order_arguments()`` order, this
    produces ``const auto[&] <name> = p_node->Input(i).<conversion>``.

    ``ivalue_type_conversion_method`` supplies an ``(is_reference, method)``
    pair. A falsy result fails the assertion. ``&`` is emitted only when
    ``is_reference`` is truthy.

    Statements are joined with ``";\\n "`` and the string ends with ``;``.
    """
    statements = []
    for position, arg in enumerate(schema.schema_order_arguments()):
        conversion = ivalue_type_conversion_method(arg.type)
        assert conversion
        by_reference, method_name = conversion
        ref_marker = "&" if by_reference else ""
        statements.append(
            f"const auto{ref_marker} {arg.name} = p_node->Input({position}).{method_name}"
        )
    return ";\n ".join(statements) + ";"
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Build an out= variant of an in-place (``SchemaKind.inplace``) schema.

    The resulting schema:
      - has the in-place marker removed from its name and the overload name
        ``out``, or ``<overload>_out`` when ``func`` already has an overload;
      - takes ``func.arguments.remove_self_annotation()`` plus one extra out
        argument named ``out``, which copies ``self``'s type and annotation.
        Returns that aliased ``self`` therefore now alias ``out``;
      - keeps ``func.returns`` unchanged.

    ``func`` must have a ``self`` argument (asserted).
    """
    assert func.kind() == SchemaKind.inplace
    self_param = func.arguments.self_arg
    assert self_param is not None

    overload = func.name.overload_name
    out_overload = f"{overload}_out" if overload else "out"
    out_name = func.name.remove_inplace().with_overload(out_overload)

    out_param = Argument(
        name="out",
        type=self_param.argument.type,
        default=None,
        annotation=self_param.argument.annotation,
    )
    out_arguments = func.arguments.remove_self_annotation().with_out_args([out_param])

    return FunctionSchema(
        name=out_name,
        arguments=out_arguments,
        returns=func.returns,
    )
def generate_non_native_lazy_ir_nodes(
    non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
    """Generate the non-native lazy IR node classes.

    Each entry of ``non_native`` is a dict with:
      - a required ``"func"`` schema string;
      - an optional ``"properties"`` list of property names set to ``True``;
      - an optional ``"opkind"``.

    One generated node (the first element of ``gen_lazy_ir.gen(...)``) is
    returned per entry, in input order.
    """

    def _build_node(op: Dict[str, Any]):
        # Non-native IRs start from these default properties; any listed in
        # op["properties"] are then switched on.
        props = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for prop_name in op.get("properties", []):
            setattr(props, prop_name, True)
        ir_schema = LazyIrSchema(FunctionSchema.parse(op["func"]), props)
        ir_schema.opkind = op.get("opkind")
        return gen_lazy_ir.gen(ir_schema)[0]

    return [_build_node(op) for op in non_native]
def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]:
    """Map each schema argument to an ``("%<name>", type_str)`` IR pair.

    Arguments are visited in ``schema.schema_order_arguments()`` order.
    ``Optional[T]`` is unwrapped to ``T``, which must then be a ``BaseType``
    (asserted). ``type_str`` is looked up in
    ``generate_test_ir_arguments_base_ty_to_type_str_`` by the base type's name,
    or is ``None`` when the name is absent. A non-empty ``type_str`` for an
    optional argument gets a trailing ``?``.
    """
    type_table = generate_test_ir_arguments_base_ty_to_type_str_

    def _to_ir_pair(arg: Argument) -> Tuple[str, Optional[str]]:
        arg_type = arg.type
        is_optional = isinstance(arg_type, OptionalType)
        if is_optional:
            arg_type = arg_type.elem
        assert isinstance(arg_type, BaseType)
        type_str = type_table[arg_type.name] if arg_type.name in type_table else None
        if type_str and is_optional:
            type_str = f"{type_str}?"
        return ("%" + arg.name, type_str)

    return [_to_ir_pair(arg) for arg in schema.schema_order_arguments()]
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return ``ufunc_<op name>_<dispatch key>`` for an out= schema.

    Raises AssertionError if ``func`` is not an out= function.
    """
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    op_name = func.name.name
    return "ufunc_{}_{}".format(op_name, dispatch_key)
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    """Return the comma-separated ``<arg><index>`` names for ``schema``'s arguments.

    Arguments are visited in ``schema.schema_order_arguments()`` order. Only
    non-out schemas are accepted (asserted).
    """
    assert not schema.is_out_fn()
    suffixed = [f"{arg.name}{index}" for arg in schema.schema_order_arguments()]
    return ",".join(suffixed)
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    """Build Python signatures for the deprecated ops listed in a YAML file.

    Each YAML entry has a ``name`` (the deprecated schema) and an ``aten``
    call expression (the ATen function it delegates to, plus call arguments).
    Every entry is matched against the ATen overloads in ``pairs`` that share
    its name. For each compatible overload, a
    ``PythonSignatureNativeFunctionPair`` wrapping a
    ``PythonSignatureDeprecated`` is emitted.

    Args:
        pairs: Original ATen signature/function pairs to match against.
        deprecated_yaml_path: Path of the deprecated-signature YAML file.
        method: Forwarded to ``signature_from_schema``.
        pyi: Forwarded to ``signature_from_schema``.

    Returns:
        One pair per (deprecated entry, compatible ATen overload) combination.

    Raises:
        AssertionError: if a call argument is neither a deprecated-schema
            argument nor a known constant, or if no ATen overload matches an
            entry.
    """
    # The deprecated.yaml doesn't have complete type information, we need
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # group the original ATen signatures by name
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    # Explicit encoding so decoding does not depend on the build machine's
    # locale. NOTE(review): yaml.load with YamlLoader is presumably fine only
    # because this is an in-tree, trusted file — confirm before pointing it
    # at external input.
    with open(deprecated_yaml_path, "r", encoding="utf-8") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            # Strip only the trailing "_out" that was just tested for;
            # str.replace would also delete any "_out" inside the base name.
            aten_name = aten_name[: -len("_out")]

        # HACK: these are fixed constants used to pass the aten function.
        # The type must be known ahead of time
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (
                name in schema_args_by_name or name in known_constants
            ), f"deprecation definiton: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        def is_schema_compatible(
            aten_schema: FunctionSchema,
        ) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                # For out= calls the out arguments are listed first.
                arguments = itertools.chain(
                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
                )
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    # ATen arguments not supplied by the call must have a default.
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns)
            )

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n {str(schema)}"

    return results
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Generate an out= schema from a mutable (``SchemaKind.mutable``) schema.

    The new out= schema has:
      - every tensor-like return converted into a mutable out= argument named
        ``out{i}``, where ``i`` is its position in ``func.returns``. Each one
        gets a fresh ``x!`` alias annotation whose letter is not used by any
        existing argument;
      - when every return is a plain ``Tensor``, returns that carry the new
        out= arguments' alias annotations (for method chaining). Otherwise the
        out arguments are not returned, and only the non-tensor-like returns
        are kept;
      - the in-place marker removed from the name and the overload ``out``,
        or ``<overload>_out`` when ``func`` already has an overload.

    Note that:
      (1) an out= variant can *only* be generated from a mutable schema with at
          least one tensor-like non-aliasing return (asserted below);
      (2) the generated out= variant still has mutable positional arguments,
          but if necessary another out= variant that also functionalizes them
          (a "functional_out" variant) could be added.
    """
    assert func.kind() == SchemaKind.mutable
    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    # Collect the alias letters already in use into a set once. concatMap may
    # return a lazy iterator, and a repeated `in` test against an iterator
    # consumes it. Later letters would then wrongly look unused and could
    # collide with existing argument annotations.
    used_annotations = set(
        concatMap(
            lambda a: [] if a.annotation is None else a.annotation.alias_set,
            func.arguments.flat_all,
        )
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for i, r in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name=f"out{i}",
                type=r.type,
                default=None,
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )