def _write_custom_ops_registrations(
    self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
):
    """Write the ``TORCH_LIBRARY(ort, m) { ... }`` block.

    One line of the form ``m.def("<name>", &<name>);`` is written inside the
    block for every entry of ``generated_funcs``. ``<name>`` is that entry's
    C++ function identifier. A blank line is written before and after the block.

    Args:
        writer: Destination for the generated C++ source.
        generated_funcs: Mapped functions whose C++ identifiers are registered.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY(ort, m) {")
    writer.push_indent()

    for func in generated_funcs:
        # The same identifier is used both as the op name string and as the
        # address of the C++ function.
        func_name = func.cpp_func.identifier.value
        writer.write("m.def(")
        writer.writeline(f'"{func_name}", &{func_name});')

    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()
def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
    """Write the ``TORCH_LIBRARY_IMPL(aten, ORT, m) { ... }`` block.

    Each generated function contributes one line::

        m.impl("<torch name>", TORCH_FN(<ns>::<cpp name>));

    The ``<ns>::`` qualifier is included only when the mapped function's
    ``op_namespace`` is truthy. ``<torch name>`` comes from the torch function
    attached to the C++ declaration. A blank line is written before and after
    the block.

    Args:
        writer: Destination for the generated C++ source.
        generated_funcs: Mapped functions to register as ``aten`` kernels.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
    writer.push_indent()

    for func in generated_funcs:
        cpp_decl = func.cpp_func
        torch_decl = cpp_decl.torch_func
        qualifier = f"{func.op_namespace}::" if func.op_namespace else ""
        kernel = f"TORCH_FN({qualifier}{cpp_decl.identifier.value})"
        writer.write("m.impl(")
        writer.writeline(f'"{torch_decl.identifier.value}", {kernel});')

    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()
def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
    """Write a ``return`` statement that forwards the call to ``native::cpu_fallback``.

    The emitted C++ has this shape (indentation follows the writer's indent
    setting)::

        return native::call_fallback_fn<
            &native::cpu_fallback,
            ATEN_OP(<cpp name>)>::call(<arg>, <arg>, ...);

    Only parameters whose ``member.identifier`` is truthy are forwarded as
    call arguments, in declaration order.

    NOTE(review): ``native::`` is emitted without an ``at::`` prefix.
    Presumably the generated file is inside, or using, ``namespace at``;
    confirm before changing.

    Args:
        writer: Destination for the generated C++ source.
        mapped_func: Supplies the C++ declaration whose identifier goes into
            ``ATEN_OP(...)`` and whose parameters form the argument list.
    """
    cpp_func = mapped_func.cpp_func

    writer.writeline("return native::call_fallback_fn<")
    writer.push_indent()
    writer.writeline("&native::cpu_fallback,")
    writer.write("ATEN_OP(")
    writer.write(cpp_func.identifier.value)
    writer.write(")>::call(")
    # Parameters without an identifier are skipped because there is no name
    # to forward.
    params = ", ".join(
        p.member.identifier.value for p in cpp_func.parameters if p.member.identifier
    )
    writer.write(params)
    writer.writeline(");")
    writer.pop_indent()
def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
    """Write the C++ signature of ``cpp_func``, up to and including the closing ``)``.

    If the declaration has an associated torch function, that function's
    schema is written first as a ``//`` comment line. The return type and
    identifier follow. Each parameter is then placed on its own indented line.
    A parameter whose type is a keyword-argument sentinel is prefixed with
    ``// ``, so the rest of its line is a C++ comment. Nothing is written
    after the closing ``)``.

    Args:
        writer: Destination for the generated C++ source.
        cpp_func: The C++ function declaration to render.
    """
    torch_func = cpp_func.torch_func
    if torch_func:
        writer.writeline(f"// {torch_func.torch_schema}")

    cpp_func.return_type.write(writer)
    writer.write(f" {cpp_func.identifier.value}(")

    writer.push_indent()
    for member in cpp_func.parameters:
        writer.writeline()
        is_kwargs_sentinel = isinstance(member.member.parameter_type, ast.KWArgsSentinelType)
        if is_kwargs_sentinel:
            # The sentinel is kept visible in the output but commented out,
            # because it is not a real C++ parameter.
            writer.write("// ")
        member.write(writer)
    writer.pop_indent()

    writer.write(")")