Beispiel #1
0
    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        """Generate the output strings for the function group ``g``.

        The kernel metadata for ``g`` is looked up on ``self.backend_index``
        and the result then depends on that index's dispatch key:

        * ``Meta`` / ``CompositeExplicitAutograd``: assert that no kernel is
          registered for ``g.out`` on this index, then continue with the
          structured path.
        * any other key with missing or non-structured metadata: fall back to
          ``self.gen_unstructured`` for every function of the group.
        * otherwise: structured path.

        The structured path hands every function in ``g.functions()`` to
        ``StructuredRegisterDispatchKey.gen_one``.

        Args:
            g: The function group to generate code for.

        Returns:
            The collected per-function results as a list of strings.

        Raises:
            AssertionError: If a kernel for ``g.out`` is present on a ``Meta``
                or ``CompositeExplicitAutograd`` index.
        """
        index = self.backend_index
        kernel_meta = index.get_kernel(g)
        key = index.dispatch_key

        if key == DispatchKey.Meta:
            assert not index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif key == DispatchKey.CompositeExplicitAutograd:
            assert not index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif kernel_meta is None or not kernel_meta.structured:
            # Nothing structured is registered for this backend: treat each
            # member of the group as an unstructured function.
            def unstructured(fn):
                return self.gen_unstructured(fn, g)

            return list(mapMaybe(unstructured, g.functions()))

        generator = StructuredRegisterDispatchKey(
            index,
            self.target,
            self.selector,
            self.rocm,
            self.cpp_namespace,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(generator.gen_one, g.functions()))
Beispiel #2
0
 def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
     """Generate output strings for a function group or a single function.

     Args:
         f: Either a ``NativeFunctionsGroup`` or a ``NativeFunction``.

     Returns:
         For a group marked ``structured``, the result of
         ``self.gen_structured``; for any other group, the results of
         ``self.gen_unstructured`` over the group's functions; for a single
         function, a list holding the ``self.gen_unstructured`` result, or
         an empty list when that result is ``None``.
     """
     if isinstance(f, NativeFunctionsGroup):
         group: NativeFunctionsGroup = f
         # A group marked structured always goes through gen_structured(),
         # whatever the backend is; gen_structured() contains the special
         # handling for auto-generated kernels.
         if group.structured:
             return self.gen_structured(group)
         return list(
             mapMaybe(lambda fn: self.gen_unstructured(fn, group), group.functions())
         )

     if isinstance(f, NativeFunction):
         generated = self.gen_unstructured(f)
         if generated is None:
             return []
         return [generated]

     # Neither of the two expected types.
     assert_never(f)
Beispiel #3
0
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write the unboxing sources and header through ``cpu_fm``.

    Three outputs are written, in this order:

    * ``UnboxingFunctions.cpp`` (sharded) -- key ``"definitions"``
    * ``UnboxingFunctions.h`` -- key ``"declarations"``
    * ``RegisterCodegenUnboxedKernels.cpp`` (sharded) -- key ``"unboxed_ops"``

    The sharded files use a single shard while the number of selected
    operators (``len(selector.operators)``) is below the threshold of 100;
    at or above it they use 5 and 10 shards respectively.

    Args:
        native_functions: Functions to generate unboxing code for.
        cpu_fm: File manager that receives the ``write`` /
            ``write_sharded`` calls.
        selector: Selective-build object; its ``operators`` are counted and
            it is passed on to the code generators.
    """
    # Best-practice number of selected operators from which sharding is enabled.
    sharding_threshold: int = 100
    enable_sharding: bool = len(selector.operators) >= sharding_threshold

    def root_name_of(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        # Sharding key passed to write_sharded.
        return fn.root_name

    def definitions_env(fn):
        generate = ComputeUnboxingFunctions(Target.DEFINITION, selector)
        return {"definitions": [generate(fn)]}

    def declarations_env():
        generate = ComputeUnboxingFunctions(Target.DECLARATION, selector)
        return {
            "declarations": list(mapMaybe(generate, native_functions)),
        }

    def unboxed_ops_env(fn):
        generate = ComputeCodegenUnboxedKernels(selector)
        return {"unboxed_ops": [generate(fn)]}

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=root_name_of,
        env_callable=definitions_env,
        num_shards=5 if enable_sharding else 1,
        sharded_keys={"definitions"},
    )
    cpu_fm.write("UnboxingFunctions.h", declarations_env)
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=root_name_of,
        env_callable=unboxed_ops_env,
        num_shards=10 if enable_sharding else 1,
        sharded_keys={"unboxed_ops"},
    )
Beispiel #4
0
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Render ``variable_factories.h`` from its template.

    The native functions are parsed from the two YAML files, filtered with
    ``is_factory_function``, and the template is filled with a generated
    comment, one ``#include`` line per factory function, and the function
    definitions produced by ``process_function``.

    Args:
        out: Install directory handed to the ``FileManager``.
        native_yaml_path: First argument to ``parse_native_yaml``.
        tags_yaml_path: Second argument to ``parse_native_yaml``.
        template_path: Template directory handed to the ``FileManager``.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    factory_functions = list(filter(is_factory_function, parsed.native_functions))
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def template_env():
        # NOTE(review): the marker is assembled from two pieces, presumably
        # so this source file never contains it contiguously -- confirm
        # before joining the literals.
        comment = "@" + f"generated from {fm.template_dir}/variable_factories.h"
        includes = []
        for fn in factory_functions:
            includes.append(f"#include <ATen/ops/{fn.root_name}.h>")
        definitions = list(mapMaybe(process_function, factory_functions))
        return {
            "generated_comment": comment,
            "ops_headers": includes,
            "function_definitions": definitions,
        }

    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        template_env,
    )
Beispiel #5
0
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction],
    backend_index: BackendIndex,
) -> List[str]:
    """Return the native function declarations for ``g`` on a backend.

    Args:
        g: A function group or a single native function.
        backend_index: Backend whose kernel metadata for ``g`` is consulted.

    Returns:
        For a single function, a list holding the ``gen_unstructured``
        result, or an empty list when that result is ``None``. For a group
        whose kernel metadata is structured, the result of
        ``gen_structured``. For any other group, the ``gen_unstructured``
        results over ``g.functions()``.

    Raises:
        AssertionError: If the group is structured and ``backend_index`` is
            an external backend.
    """
    kernel_meta = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        declaration = gen_unstructured(g, backend_index)
        if declaration is None:
            return []
        return [declaration]

    if kernel_meta is None or not kernel_meta.structured:
        def declare(fn):
            return gen_unstructured(fn, backend_index)

        return list(mapMaybe(declare, g.functions()))

    if backend_index.external:
        # Structured kernels have not been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
def gen_all_vmap_plumbing(native_functions):
    body = '\n'.join(
        list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
    return f"""
Beispiel #7
0
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    body = "\n".join(
        list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
    return f"""