Beispiel #1
0
def generate_non_native_lazy_ir_nodes(non_native: List[Dict[str, Any]],
                                      gen_lazy_ir: GenLazyIR) -> List[str]:
    """Build one lazy IR node class per non-native op entry.

    Each entry of ``non_native`` must provide a ``"func"`` schema string and
    may provide ``"properties"`` (attribute names to switch on) and
    ``"opkind"``. The result holds, in input order, the first element
    returned by ``gen_lazy_ir.gen`` for each entry.
    """
    def build_node(entry: Dict[str, Any]) -> str:
        # Non-native IRs start from these defaults; the entry may enable more.
        props = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for flag in entry.get("properties", []):
            setattr(props, flag, True)

        ir_schema = LazyIrSchema(FunctionSchema.parse(entry["func"]), props)
        # Stays None when the entry does not declare an opkind.
        ir_schema.opkind = entry.get("opkind")
        return gen_lazy_ir.gen(ir_schema)[0]

    return [build_node(entry) for entry in non_native]
Beispiel #2
0
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    """Pair every deprecated signature with the ATen functions it delegates to.

    The deprecated YAML doesn't have complete type information, so each entry
    is joined with the original ATen signatures (taken from ``pairs``) using
    a type-only comparison, and a full python signature is generated for
    every compatible match.

    Args:
        pairs: Original signature/function pairs; grouped here by
            ``pair.signature.name``.
        deprecated_yaml_path: Path of the YAML file listing deprecated
            entries. Each entry is read through its ``"name"`` key (parsed as
            a function schema) and its ``"aten"`` key (the call it delegates
            to).
        method: Forwarded to ``signature_from_schema``.
        pyi: Forwarded to ``signature_from_schema``.

    Returns:
        One pair per (deprecated entry, compatible ATen function)
        combination, in file order. Empty when the YAML file has no entries.

    Raises:
        AssertionError: If a call argument of a deprecated entry is neither
            an argument of its schema nor a known constant, or if no ATen
            function matches the deprecated signature.
    """
    # group the original ATen signatures by name
    grouped: Dict[str,
                  List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    with open(deprecated_yaml_path, "r") as f:
        # NOTE(review): safety of this load depends on YamlLoader being a
        # safe loader -- confirm where it is defined.
        # An empty YAML document loads as None; treat that as "no entries"
        # instead of failing when iterating below.
        deprecated_defs = yaml.load(f, Loader=YamlLoader) or []

    # HACK: these are fixed constants used to pass to the aten function.
    # The type must be known ahead of time. The table never changes, so it
    # is built once rather than once per deprecated entry.
    known_constants = {
        "1": Type.parse("Scalar"),
    }
    out_suffix = "_out"

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith(out_suffix)
        if is_out:
            # Strip only the trailing suffix; str.replace would also remove
            # any "_out" occurring earlier in the name.
            aten_name = aten_name[:-len(out_suffix)]

        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (name in schema_args_by_name or name in known_constants
                    ), f"deprecation definiton: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        def is_schema_compatible(aten_schema: FunctionSchema, ) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                # For out variants the out arguments are matched first.
                arguments = itertools.chain(aten_schema.arguments.out,
                                            aten_schema.arguments.flat_non_out)
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    # ATen arguments not supplied by the deprecated call
                    # must have a default value.
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns))

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                ))
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n  {str(schema)}"

    return results