Exemple #1
0
    def _write_custom_ops_registrations(
        self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
    ):
        """Emit a ``TORCH_LIBRARY(ort, m) { ... }`` block defining each generated function.

        For every entry of ``generated_funcs`` a line of the form
        ``m.def("<name>", &<name>);`` is written inside the block, where
        ``<name>`` is the C++ identifier of the mapped function. The block is
        surrounded by blank lines.

        Args:
            writer: Destination for the generated C++ source.
            generated_funcs: Mapped functions whose C++ identifiers are registered.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY(ort, m) {")
        writer.push_indent()

        for func in generated_funcs:
            native = func.cpp_func
            writer.write("m.def(")
            symbol = native.identifier.value
            # The same identifier serves as both the registered name and the
            # address of the C++ function being registered.
            writer.writeline(f'"{symbol}", &{symbol});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
Exemple #2
0
    def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
        """Emit a ``TORCH_LIBRARY_IMPL(aten, ORT, m) { ... }`` block binding torch ops to C++ functions.

        Each mapped function produces one line::

            m.impl("<torch op name>", TORCH_FN(<namespace>::<cpp name>));

        The ``<namespace>::`` qualifier is included only when the mapped
        function's ``op_namespace`` is truthy. The block is surrounded by
        blank lines.

        Args:
            writer: Destination for the generated C++ source.
            generated_funcs: Mapped functions to register; each is expected to
                carry a ``cpp_func`` with an associated ``torch_func``.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
        writer.push_indent()

        for func in generated_funcs:
            native = func.cpp_func
            torch_op = native.torch_func

            # Qualify the C++ symbol with its namespace only when one is set.
            namespace = func.op_namespace
            qualified = (f"{namespace}::" if namespace else "") + native.identifier.value

            writer.write("m.impl(")
            writer.writeline(f'"{torch_op.identifier.value}", TORCH_FN({qualified}));')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
Exemple #3
0
    def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
        """Emit a C++ ``return`` statement that forwards the op to the CPU fallback.

        The generated code has the shape::

            return native::call_fallback_fn<
              &native::cpu_fallback,
              ATEN_OP(<cpp name>)>::call(<arg>, <arg>, ...);

        ``native::`` is written unqualified. NOTE(review): this presumably
        relies on the generated code being inside ``namespace at`` (or an
        equivalent using-directive); confirm against the emitting caller.

        Args:
            writer: Destination for the generated C++ source.
            mapped_func: The mapped op whose ``cpp_func`` supplies the ATen op
                name and the parameter names forwarded to ``call``.
        """
        cpp_func = mapped_func.cpp_func

        writer.writeline("return native::call_fallback_fn<")
        writer.push_indent()
        writer.writeline("&native::cpu_fallback,")
        writer.write("ATEN_OP(")
        writer.write(cpp_func.identifier.value)
        writer.write(")>::call(")

        # Only parameters that carry an identifier can be forwarded by name;
        # presumably the unnamed ones are markers such as the kwargs `*`
        # sentinel (TODO confirm).
        params = ", ".join(
            p.member.identifier.value for p in cpp_func.parameters if p.member.identifier
        )
        writer.write(params)
        writer.writeline(");")
        writer.pop_indent()
Exemple #4
0
 def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
     """Write the C++ signature of ``cpp_func``: return type, name, and parameter list.

     If ``cpp_func`` has an associated torch function, its schema is first
     emitted as a ``//`` comment line. Each parameter is then written on its
     own line, indented one level. A parameter whose type is the kwargs
     sentinel is prefixed with ``// `` so it is commented out in the output.
     The closing ``)`` is written without a trailing newline.

     Args:
         writer: Destination for the generated C++ source.
         cpp_func: Declaration whose signature is written.
     """
     schema_source = cpp_func.torch_func
     if schema_source:
         writer.writeline(f"// {schema_source.torch_schema}")

     cpp_func.return_type.write(writer)
     writer.write(f" {cpp_func.identifier.value}(")

     writer.push_indent()
     for entry in cpp_func.parameters:
         writer.writeline()
         # The kwargs sentinel has no C++ counterpart, so it is emitted as a comment.
         commented_out = isinstance(entry.member.parameter_type, ast.KWArgsSentinelType)
         if commented_out:
             writer.write("// ")
         entry.write(writer)
     writer.pop_indent()

     writer.write(")")