示例#1
0
文件: generator.py 项目: Mu-L/pytorch
    def view(self, groups: Sequence[NativeFunctionsViewGroup],
             backend_index: BackendIndex) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_generator(
                    g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated
示例#2
0
 def view_op_generator(self, g: NativeFunctionsViewGroup,
                       backend_index: BackendIndex) -> str:
     """Return the C++ ``if`` branch that handles one view-op overload.

     The branch tests the node against the schema string of ``g.view.func``.
     On a match, it returns a ``ProcessedNode`` lambda that:

     - runs the argument-extraction code from ``generate_arg_extraction``;
     - assigns the result of the call built by ``generate_call_to_view_ops``
       to ``p_node->Output(0)``.

     Args:
         g: the view group whose ``view.func`` schema is being emitted.
         backend_index: forwarded unchanged to ``generate_call_to_view_ops``.

     Returns:
         The generated C++ snippet as a string.
     """
     schema = str(g.view.func)
     populated_argument = generate_arg_extraction(g.view.func)
     functional_variant_call = generate_call_to_view_ops(g, backend_index)
     # Doubled braces are literal C++ braces inside the f-string.
     generated = f"""
   if (n->matches(torch::schema("aten::{schema}"))) {{
     return [](ProcessedNode* p_node) {{
       {populated_argument}
         p_node->Output(0) = {functional_variant_call};
     }};
   }}"""
     return generated
def group_functions_by_op_name(
    grouped_native_functions: Sequence[NativeGroupT],
) -> Sequence[Sequence[NativeGroupT]]:
    """Bucket the supported native-function groups by their base op name.

    A group is kept only if ``generator.is_supported`` returns True for it,
    evaluated inside ``native_function_manager``. Kept groups are bucketed by
    ``config.func_name_base_str``:

    - Groups with the same name land in one bucket, even when they are not
      adjacent in the input.
    - Buckets appear in the order their name is first seen.
    - Members of a bucket keep their input order.

    Args:
        grouped_native_functions: candidate function groups.

    Returns:
        A list of non-empty lists, one per distinct op name. Returns ``[]``
        for empty input or when no group is supported.
    """

    def is_supported(
            g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
        with native_function_manager(g):
            return generator.is_supported(g)

    # itertools.groupby only merges *consecutive* equal keys, so overloads of
    # one op that are not adjacent in the input would be split into separate
    # groups (and then registered more than once). An insertion-ordered dict
    # merges them, and it gives the same result as groupby whenever the input
    # is already contiguous per name.
    buckets: "dict[str, list[NativeGroupT]]" = {}
    for g in grouped_native_functions:
        if is_supported(g):
            buckets.setdefault(config.func_name_base_str(g), []).append(g)
    return list(buckets.values())