Beispiel #1
0
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Build C++ declaration strings for the kernels named in a group.

    Two passes are made, and their results are returned in this order:

    1. For each distinct kernel name found under a structured dispatch key
       in ``g.out.dispatch``, a ``struct TORCH_API structured_<name>``
       declaration deriving from ``at::meta::<meta name>`` with a single
       ``impl(...)`` member.
    2. For each distinct kernel name found under a non-structured dispatch
       key in the dispatch table of any function in ``g.functions()``, a
       plain ``TORCH_API <return type> <name>(<args>);`` declaration.

    Args:
        g: The function group to generate declarations for.

    Returns:
        The declaration strings, structs first, then free functions.
    """
    declarations: List[str] = []
    base_class = meta.name(g)
    impl_params = structured.impl_arguments(g)

    # Pass 1: struct declarations. Only the out function's dispatch table
    # is consulted here.
    structured_names: Set[Any] = set()
    for key, kernel in g.out.dispatch.items():
        # Skip kernel names already emitted and keys that are not structured.
        if kernel in structured_names or not is_structured_dispatch_key(key):
            continue
        structured_names.add(kernel)
        impl_params_str = ', '.join(p.decl() for p in impl_params)
        declarations.append(
            f"struct TORCH_API structured_{kernel} : public at::meta::{base_class} {{\n"
            f"void impl({impl_params_str});\n"
            "};\n"
        )

    # Pass 2: free-function declarations for every non-structured key.
    # Deduplication starts afresh and is independent of pass 1.
    unstructured_names: Set[Any] = set()
    for fn in g.functions():
        ret_type = native.returns_type(fn.func.returns).cpp_type()
        params = native.arguments(fn.func)
        for key, kernel in fn.dispatch.items():
            if kernel in unstructured_names or is_structured_dispatch_key(key):
                continue
            unstructured_names.add(kernel)
            params_str = ', '.join(p.decl() for p in params)
            declarations.append(f"TORCH_API {ret_type} {kernel}({params_str});")

    return declarations
Beispiel #2
0
 def flatten_pre_group(
     d: Dict[SchemaKind, NativeFunction]
 ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
     """Collapse a schema-kind mapping into a group when one can be built.

     ``NativeFunctionsGroup.from_dict`` is tried first. If it yields a
     group, that group is returned as a one-element list; if it yields
     ``None``, the individual functions from ``d`` are returned instead.
     """
     group = NativeFunctionsGroup.from_dict(d)
     return [group] if group is not None else list(d.values())
Beispiel #3
0
    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        """Produce the generated strings for every function in group ``g``.

        The behaviour depends on ``self.backend_index``:

        * Dispatch key ``Meta`` or ``CompositeExplicitAutograd``: assert that
          the backend index has no kernel for ``g.out``, then continue on the
          structured path below.
        * Any other key where ``get_kernel(g)`` returns ``None`` or metadata
          that is not marked ``structured``: run ``self.gen_unstructured`` on
          each function of the group and return those results.
        * Otherwise: run ``StructuredRegisterDispatchKey.gen_one`` on each
          function of the group.

        Results are collected through ``mapMaybe`` (presumably dropping
        ``None`` results -- confirm against its definition).

        Raises:
            AssertionError: if a kernel for ``g.out`` is present under the
                ``Meta`` or ``CompositeExplicitAutograd`` dispatch key.
        """
        index = self.backend_index
        metadata = index.get_kernel(g)
        key = index.dispatch_key

        if key == DispatchKey.Meta:
            assert not index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif key == DispatchKey.CompositeExplicitAutograd:
            assert not index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif not (metadata is not None and metadata.structured):
            # No structured kernel for this backend: handle each function
            # through the unstructured generator instead.
            return list(
                mapMaybe(
                    lambda f: self.gen_unstructured(f, g),
                    g.functions(),
                )
            )

        generator = StructuredRegisterDispatchKey(
            index,
            self.target,
            self.selector,
            self.rocm,
            self.cpp_namespace,
            self.class_method_name,
            g,
        )
        return list(mapMaybe(generator.gen_one, g.functions()))
Beispiel #4
0
    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        """Produce the generated strings for group ``g`` under ``self.dispatch_key``.

        * ``Meta`` or ``CompositeExplicitAutograd``: assert that the key is
          not listed in ``g.out.dispatch``, then continue on the structured
          path below.
        * A key for which ``is_structured_dispatch_key`` is false: run
          ``self.gen_unstructured`` on each function of the group.
        * A structured key that is absent from ``g.out.dispatch``: return an
          empty list.
        * Otherwise: run ``StructuredRegisterDispatchKey.gen_one`` on each
          function of the group.

        Results are collected through ``mapMaybe`` (presumably dropping
        ``None`` results -- confirm against its definition).

        Raises:
            AssertionError: if ``Meta`` or ``CompositeExplicitAutograd`` is
                explicitly listed in ``g.out.dispatch``.
        """
        key = self.dispatch_key

        if key == DispatchKey.Meta:
            assert key not in g.out.dispatch, (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif key == DispatchKey.CompositeExplicitAutograd:
            assert key not in g.out.dispatch, (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        else:
            if not is_structured_dispatch_key(key):
                # Non-structured key: each function is handled individually.
                return list(mapMaybe(self.gen_unstructured, g.functions()))
            if key not in g.out.dispatch:
                # Structured key with no entry on the out function.
                return []

        generator = StructuredRegisterDispatchKey(
            key, self.target, self.selector, self.rocm, g
        )
        return list(mapMaybe(generator.gen_one, g.functions()))