def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
    """Write the C++ signature of ``cpp_func`` to ``writer``.

    Output, in order:

    * a ``// <torch_schema>`` comment line, only when ``cpp_func.torch_func``
      is truthy;
    * the return type (written by ``cpp_func.return_type.write``);
    * a space, the function identifier and an opening parenthesis;
    * every entry of ``cpp_func.parameters`` on its own line, one indent
      level deeper;
    * the closing parenthesis.

    Nothing is written after the closing parenthesis (no ``;``, no body and
    no trailing newline).
    """
    torch_func = cpp_func.torch_func
    if torch_func:
        writer.writeline(f"// {torch_func.torch_schema}")

    cpp_func.return_type.write(writer)
    writer.write(f" {cpp_func.identifier.value}(")

    writer.push_indent()
    for entry in cpp_func.parameters:
        # Each parameter starts on a fresh line.
        writer.writeline()
        is_kwargs_sentinel = isinstance(entry.member.parameter_type, ast.KWArgsSentinelType)
        if is_kwargs_sentinel:
            # The sentinel is emitted as a C++ line comment rather than as a
            # real parameter.
            writer.write("// ")
        entry.write(writer)
    writer.pop_indent()

    writer.write(")")
def _write_custom_ops_registrations(
    self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
):
    """Write a ``TORCH_LIBRARY(ort, m)`` block registering the generated functions.

    The block is preceded and followed by an empty ``writer.writeline()``
    call. Inside it, each entry of ``generated_funcs`` produces one line of
    the form ``m.def("<name>", &<name>);`` where ``<name>`` is
    ``mapped_func.cpp_func.identifier.value``.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY(ort, m) {")
    writer.push_indent()

    for generated in generated_funcs:
        writer.write("m.def(")
        # The same identifier is used for the registered name and for the
        # address of the C++ function.
        func_name = generated.cpp_func.identifier.value
        writer.writeline(f'"{func_name}", &{func_name});')

    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()
def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
    """Write a ``return`` statement that forwards the call to the CPU fallback.

    The emitted C++ has this shape (for a function named ``eq_Tensor`` with
    parameters ``self`` and ``other``)::

        return native::call_fallback_fn<
          &native::cpu_fallback,
          ATEN_OP(eq_Tensor)>::call(self, other);

    The call arguments are the identifiers of ``cpp_func.parameters``, in
    order; parameters whose ``member.identifier`` is falsy are left out.
    """
    cpp_func = mapped_func.cpp_func

    writer.writeline("return native::call_fallback_fn<")
    writer.push_indent()
    writer.writeline("&native::cpu_fallback,")
    writer.write("ATEN_OP(")
    writer.write(cpp_func.identifier.value)
    writer.write(")>::call(")

    # Only parameters that carry an identifier can be forwarded by name.
    call_args = ", ".join(
        param.member.identifier.value for param in cpp_func.parameters if param.member.identifier
    )
    writer.write(call_args)
    writer.writeline(");")
    writer.pop_indent()
def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter):
    """Generate the complete output file into ``writer``.

    Steps, in order:

    1. Write the file prelude.
    2. For every function yielded by ``self._parse_mapped_function_decls``:
       remove its entry from ``self._mapped_ops``, open or close C++
       namespaces so the function lands in ``mapped_func.op_namespace``,
       write its signature, then either ``;`` (signature only) or a braced
       body.
    3. Close the namespace that is still open, if any.
    4. Write the registration block: the custom-ops form when
       ``self._custom_ops`` is truthy, the standard form otherwise.
    5. Write the file postlude.

    Raises:
        Exception: if entries remain in ``self._mapped_ops`` once everything
            has been written; the message lists the leftover keys.
    """
    self._write_file_prelude(writer)

    emitted = []
    open_ns = None

    for mapped_func in self._parse_mapped_function_decls(cpp_parser):
        # Whatever is still in _mapped_ops at the end was never parsed.
        del self._mapped_ops[mapped_func.mapped_op_name]
        emitted.append(mapped_func)

        func_ns = mapped_func.op_namespace

        # Leave the namespace of the previous function when it differs.
        if open_ns and open_ns != func_ns:
            writer.pop_namespace()
            open_ns = None

        # Enter this function's namespace unless it is already the open one.
        if func_ns != open_ns:
            open_ns = func_ns
            writer.writeline()
            writer.push_namespace(func_ns)

        writer.writeline()
        self._write_function_signature(writer, mapped_func.cpp_func)

        if mapped_func.signature_only:
            writer.writeline(";")
            continue

        writer.writeline(" {")
        writer.push_indent()
        self._write_function_body(writer, mapped_func)
        writer.pop_indent()
        writer.writeline("}")

    if open_ns:
        writer.pop_namespace()

    if self._custom_ops:
        write_registrations = self._write_custom_ops_registrations
    else:
        write_registrations = self._write_function_registrations
    write_registrations(writer, emitted)

    self._write_file_postlude(writer)

    # The check runs only after the whole file, postlude included, is written.
    if len(self._mapped_ops) == 0:
        return

    leftover = ", ".join([f"'{o}'" for o in self._mapped_ops.keys()])
    raise Exception("Torch operation(s) could not be parsed for mapping: " + leftover)
def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
    """Write a ``TORCH_LIBRARY_IMPL(aten, ORT, m)`` block for the generated functions.

    The block is preceded and followed by an empty ``writer.writeline()``
    call. Each entry of ``generated_funcs`` produces one line of the form
    ``m.impl("<torch name>", TORCH_FN(<qualified C++ name>));`` where:

    * ``<torch name>`` is ``cpp_func.torch_func.identifier.value``;
    * ``<qualified C++ name>`` is ``cpp_func.identifier.value``, prefixed
      with ``<op_namespace>::`` when ``mapped_func.op_namespace`` is truthy.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
    writer.push_indent()

    for generated in generated_funcs:
        cpp_func = generated.cpp_func
        torch_func = cpp_func.torch_func

        namespace = generated.op_namespace
        ns_prefix = f"{namespace}::" if namespace else ""
        qualified_name = ns_prefix + cpp_func.identifier.value

        writer.write("m.impl(")
        torch_fn_arg = f"TORCH_FN({qualified_name})"
        writer.writeline(f'"{torch_func.identifier.value}", {torch_fn_arg});')

    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()