示例#1
0
 def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
     """Write the C++ declarator for ``cpp_func`` (no trailing ``;`` and no body).

     When the function carries a ``torch_func``, its torch schema is first written
     as a ``//`` comment line. Then the return type, the identifier and ``(`` are
     written, followed by each parameter on its own indented line, and finally ``)``.
     Parameters whose type is ``ast.KWArgsSentinelType`` are prefixed with ``// ``
     so they end up commented out in the generated C++.

     Args:
         writer: destination for the generated C++ source.
         cpp_func: declaration to render.
     """
     torch_func = cpp_func.torch_func
     if torch_func:
         writer.writeline(f"// {torch_func.torch_schema}")

     cpp_func.return_type.write(writer)
     writer.write(f" {cpp_func.identifier.value}(")

     writer.push_indent()
     for param in cpp_func.parameters:
         writer.writeline()
         is_kwargs_sentinel = isinstance(param.member.parameter_type, ast.KWArgsSentinelType)
         if is_kwargs_sentinel:
             # The kwargs sentinel has no C++ counterpart; keep it only as a comment.
             writer.write("// ")
         param.write(writer)
     writer.pop_indent()

     writer.write(")")
示例#2
0
    def _write_custom_ops_registrations(
        self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
    ):
        """Write a ``TORCH_LIBRARY(ort, m)`` block that ``m.def``-s each generated function.

        Each function is registered with its C++ identifier as the op name. When a
        function has an ``op_namespace``, the function pointer is qualified with it
        (``&ns::name``), the same way ``_write_function_registrations`` qualifies its
        ``TORCH_FN`` arguments. This block is emitted outside the namespace the
        function was defined in, so an unqualified ``&name`` would not resolve there.

        Args:
            writer: destination for the generated C++ source.
            generated_funcs: the functions emitted earlier in the file, in order.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY(ort, m) {")
        writer.push_indent()

        for mapped_func in generated_funcs:
            func_name = mapped_func.cpp_func.identifier.value
            # Qualify with the namespace the function was emitted into, if any.
            if mapped_func.op_namespace:
                func_ref = f"{mapped_func.op_namespace}::{func_name}"
            else:
                func_ref = func_name
            writer.write("m.def(")
            writer.writeline(f'"{func_name}", &{func_ref});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
示例#3
0
    def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
        """Write a ``return`` statement that forwards the call to PyTorch's CPU fallback.

        The emitted C++ looks roughly like::

            return native::call_fallback_fn<
              &native::cpu_fallback,
              ATEN_OP(<cpp identifier>)>::call(<named params>);

        ``native::`` is written without an ``at::`` prefix. Presumably the generated
        file makes ``at`` visible (e.g. through ``using namespace at``); this should
        be confirmed against ``_write_file_prelude``.

        Args:
            writer: destination for the generated C++ source.
            mapped_func: op whose ``cpp_func`` supplies the ``ATEN_OP`` name and the
                arguments to forward.
        """
        cpp_func = mapped_func.cpp_func
        writer.writeline("return native::call_fallback_fn<")
        writer.push_indent()
        writer.writeline("&native::cpu_fallback,")
        writer.write("ATEN_OP(")
        writer.write(cpp_func.identifier.value)
        writer.write(")>::call(")

        # Forward every parameter that has an identifier; members without one are skipped.
        params = ", ".join(p.member.identifier.value for p in cpp_func.parameters if p.member.identifier)
        writer.write(params)
        writer.writeline(");")
        writer.pop_indent()
示例#4
0
    def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter):
        """Generate the complete C++ source for all mapped ops found by ``cpp_parser``.

        Writes the file prelude and then every mapped function, either as a
        declaration only (``signature_only``) or as a declaration plus body. Functions
        are grouped into C++ namespaces by ``op_namespace``. After that come the
        registration block (custom-op or aten-impl style, depending on
        ``self._custom_ops``) and the file postlude.

        Each mapped function that is emitted is removed from ``self._mapped_ops``.
        Any entries left afterwards could not be matched to a parsed declaration.

        Raises:
            Exception: if ``self._mapped_ops`` is non-empty after generation. This is
                raised only after the whole file has been written to ``writer``.
        """
        self._write_file_prelude(writer)

        generated_funcs = []
        current_ns = None  # namespace currently open in the output (None if none)

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            # Mark this op as handled; whatever remains at the end failed to map.
            del self._mapped_ops[mapped_func.mapped_op_name]
            generated_funcs.append(mapped_func)

            # Treat any falsy namespace ("" as well as None) as "no namespace". Without
            # this, an empty string would open an empty namespace that the truthiness
            # checks would never close.
            ns = mapped_func.op_namespace or None
            if ns != current_ns:
                if current_ns:
                    writer.pop_namespace()
                if ns:
                    writer.writeline()
                    writer.push_namespace(ns)
                current_ns = ns

            writer.writeline()

            self._write_function_signature(writer, mapped_func.cpp_func)
            if mapped_func.signature_only:
                writer.writeline(";")
            else:
                writer.writeline(" {")
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline("}")

        # Close the last namespace so the registrations are emitted at the outer scope.
        if current_ns:
            writer.pop_namespace()

        if self._custom_ops:
            self._write_custom_ops_registrations(writer, generated_funcs)
        else:
            self._write_function_registrations(writer, generated_funcs)
        self._write_file_postlude(writer)

        if self._mapped_ops:
            raise Exception(
                "Torch operation(s) could not be parsed for mapping: "
                + ", ".join(f"'{o}'" for o in self._mapped_ops.keys())
            )
示例#5
0
    def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
        """Write a ``TORCH_LIBRARY_IMPL(aten, ORT, m)`` block implementing each generated function.

        For every function, one line is emitted:
        ``m.impl("<torch op identifier>", TORCH_FN(<ns>::<cpp identifier>));``.
        The ``<ns>::`` prefix appears only when ``op_namespace`` is truthy.

        Args:
            writer: destination for the generated C++ source.
            generated_funcs: the functions emitted earlier in the file, in order.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
        writer.push_indent()

        for generated in generated_funcs:
            cpp_decl = generated.cpp_func
            namespace_prefix = f"{generated.op_namespace}::" if generated.op_namespace else ""
            impl_ref = "TORCH_FN(" + namespace_prefix + cpp_decl.identifier.value + ")"
            writer.write("m.impl(")
            writer.writeline(f'"{cpp_decl.torch_func.identifier.value}", {impl_ref});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()