Esempio n. 1
0
    def _parse_function_decls(self, cpp_parser: parser.CPPParser):
        """Yield a linked Torch/C++ declaration for each commented C++ function decl.

        The translation unit returned by ``cpp_parser.parse_translation_unit()`` is
        walked in order. For every C++ function whose terminating semicolon has
        trailing trivia, the first single-line comment in that trivia (which carries
        the Torch schema) is passed, together with the function, to
        ``self._parse_and_link_torch_function_decl``, and its result is yielded.
        Functions with no semicolon, no trailing trivia, or no single-line comment in
        that trivia produce nothing.
        """
        translation_unit = cpp_parser.parse_translation_unit()

        for decl in translation_unit:
            semicolon = decl.semicolon
            if not semicolon or not semicolon.trailing_trivia:
                continue

            # Only the first single-line comment after the ';' is used as the schema.
            schema_comment = next(
                (
                    token
                    for token in semicolon.trailing_trivia
                    if token.kind == lexer.TokenKind.SINGLE_LINE_COMMENT
                ),
                None,
            )
            if schema_comment is not None:
                yield self._parse_and_link_torch_function_decl(decl, schema_comment)
Esempio n. 2
0
# Command-line driver: load an ops description module from a file path, build an
# ORTGen from it, and generate code from a RegistrationDeclarations header.
parser = argparse.ArgumentParser(description="Generate ORT ATen operations")
parser.add_argument(
    "--ops_module",
    type=str,
    # The module is loaded from this path below; without it nothing can be loaded.
    required=True,
    help="Python module containing the Onnx Operation signature and list of ops to map",
)
parser.add_argument("--output_file", default=None, type=str, help="Output file [default to std out]")
parser.add_argument("--header_file", type=str, help="Header file which contains ATen / Pytorch operation signature")
parser.add_argument(
    "--custom_ops", action="store_true", help="Whether we are generating code for custom ops or native operation"
)

args = parser.parse_args()

# Execute the ops module from its file path. It must define `ops`,
# `type_promotion_ops` and `aten_output_type` (all read below).
# NOTE(review): this executes arbitrary Python from the given path; only pass trusted files.
loader = SourceFileLoader("", args.ops_module)
ops_module = types.ModuleType(loader.name)
loader.exec_module(ops_module)

ortgen = ORTGen(
    ops_module.ops,
    type_promotion_ops=ops_module.type_promotion_ops,
    custom_ops=args.custom_ops,
    aten_output_type=ops_module.aten_output_type,
)

regdecs_path = args.header_file
# Diagnostics go to stderr: stdout may be the destination of the generated code,
# and an INFO line there would corrupt it.
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}", file=sys.stderr)

# Write to the requested file, or to stdout when no file is given.
output = open(args.output_file, "wt") if args.output_file else sys.stdout
try:
    with CPPParser(regdecs_path) as cpp_parser, SourceWriter(output) as writer:
        ortgen.run(cpp_parser, writer)
finally:
    # Close only a file this script opened; leave sys.stdout alone.
    if output is not sys.stdout:
        output.close()