def view(
    self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
) -> str:
    """Render one operator-registration snippet for a list of view groups.

    Every group contributes one branch, produced by
    ``self.view_op_generator``. The branches are joined with newlines and
    embedded in a single ``REGISTER_NATIVE_OPERATOR_FUNCTOR`` block whose
    operator name is derived from the first group.

    Args:
        groups: View groups to emit. The operator name is taken from
            ``groups[0]`` only. NOTE(review): presumably every group in the
            list shares that name -- confirm against the caller.
        backend_index: Passed through unchanged to
            ``self.view_op_generator``.

    Returns:
        The generated text, or ``""`` when ``groups`` is empty.

    Raises:
        AssertionError: If a group fails ``is_supported`` or is not a
            ``NativeFunctionsViewGroup``.
    """
    if not groups:
        return ""

    branches = []
    for group in groups:
        # Both the sanity checks and the generation run inside the group's
        # native_function_manager context.
        with native_function_manager(group):
            assert is_supported(group)
            assert isinstance(group, NativeFunctionsViewGroup)
            branches.append(self.view_op_generator(group, backend_index))

    op_name = config.func_name_base_str(groups[0])
    body = "\n".join(branches)
    # NOTE(review): the template's internal whitespace is kept exactly as it
    # appears in this file -- confirm whether multi-line output is expected.
    return f""" REGISTER_NATIVE_OPERATOR_FUNCTOR( aten::{op_name}, aten_{op_name}, [](Node* n) -> SROperator {{ {body} LogAndDumpSchema(n); return nullptr; }}); """
def view_op_generator(
    self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    """Render the schema-matching branch for a single view group.

    The returned text tests ``n->matches(...)`` against the string form of
    ``g.view.func``. On a match it returns a lambda whose body is the text
    from ``generate_arg_extraction`` followed by an assignment of the
    ``generate_call_to_view_ops`` result to ``p_node->Output(0)``.

    Args:
        g: View group whose ``view.func`` supplies the schema string and the
            input to ``generate_arg_extraction``.
        backend_index: Passed through to ``generate_call_to_view_ops``.

    Returns:
        The generated branch text.
    """
    schema = str(g.view.func)
    populated_argument = generate_arg_extraction(g.view.func)
    functional_variant_call = generate_call_to_view_ops(g, backend_index)
    # NOTE(review): the template's internal whitespace is kept exactly as it
    # appears in this file -- confirm whether multi-line output is expected.
    return f""" if (n->matches(torch::schema("aten::{schema}"))) {{ return [](ProcessedNode* p_node) {{ {populated_argument} p_node->Output(0) = {functional_variant_call}; }}; }}"""
def group_functions_by_op_name(
    grouped_native_functions: Sequence[NativeGroupT],
) -> Sequence[Sequence[NativeGroupT]]:
    """Bucket the supported groups by operator base name.

    Groups rejected by ``generator.is_supported`` (evaluated inside the
    group's ``native_function_manager`` context) are dropped. The remaining
    groups are split with ``itertools.groupby``, keyed by
    ``config.func_name_base_str``.

    NOTE: ``itertools.groupby`` only merges *adjacent* items with equal
    keys, so groups that share a name land in the same bucket only when they
    are consecutive in the input.
    NOTE(review): presumably the caller supplies input in which same-named
    groups are already contiguous -- confirm.

    Args:
        grouped_native_functions: Candidate groups, in the order in which
            they should be scanned.

    Returns:
        A list of lists; each inner list holds consecutive supported groups
        with the same key, in input order. Empty input yields ``[]``.
    """
    if not grouped_native_functions:
        return []

    def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
        # The support check runs inside the per-group context.
        with native_function_manager(g):
            return generator.is_supported(g)

    # Lazy filter: groupby consumes it in a single pass.
    eligible_ops = (g for g in grouped_native_functions if is_supported(g))
    return [
        list(bucket)
        for _, bucket in itertools.groupby(
            eligible_ops, key=config.func_name_base_str
        )
    ]