示例#1
0
文件: gen.py 项目: dreiss/pytorch
def parse_native_yaml(path: str) -> ParsedYaml:
    """Parse the YAML file at ``path`` into a ``ParsedYaml``, memoized per path.

    On the first call for a given ``path`` the file is loaded with
    ``LineLoader``, every entry is converted with ``NativeFunction.from_yaml``,
    the per-dispatch-key kernel tables are grown via
    ``BackendIndex.grow_index``, and the collected functions are passed to
    ``error_check_native_functions``. The result is stored in
    ``_GLOBAL_PARSE_NATIVE_YAML_CACHE`` and returned; later calls with the
    same ``path`` return the cached object without re-reading the file.

    Raises:
        AssertionError: if the loaded document is not a list, or an entry has
            no integer ``__line__`` value.
    """
    # Fast path: this file has already been parsed.
    if path in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]

    with open(path, 'r') as yaml_file:
        entries = yaml.load(yaml_file, Loader=LineLoader)
    assert isinstance(entries, list)

    native_functions: List[NativeFunction] = []
    kernels_by_key: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for entry in entries:
        line_no = entry.get('__line__')
        assert isinstance(line_no, int), entry
        location = Location(path, line_no)
        declaration = entry.get('func')
        # The lambda is evaluated lazily by `context`, so it reads the
        # location/declaration of the entry currently being processed.
        with context(lambda: f'in {location}:\n  {declaration}'):
            native_function, backend_metadata = NativeFunction.from_yaml(entry, location)
            native_functions.append(native_function)
            BackendIndex.grow_index(kernels_by_key, backend_metadata)
    error_check_native_functions(native_functions)

    def _empty_index() -> BackendIndex:
        # Fallback for a dispatch key that has no kernels yet, so that the
        # codegen does not fail when looking such a key up.
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            index={})

    indices: Dict[DispatchKey, BackendIndex] = defaultdict(_empty_index)
    for dispatch_key, kernel_table in kernels_by_key.items():
        # All structured in-tree operators are implemented in terms of their
        # out operator.
        indices[dispatch_key] = BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=True,
            external=False,
            index=kernel_table)

    _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(native_functions, indices)
    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
示例#2
0
def requires_backend_wrapper(f: NativeFunction,
                             backend_index: BackendIndex) -> bool:
    """Decide whether ``f`` needs a backend wrapper.

    A wrapper is required when the operator name matches none of the patterns
    in ``_FN_DENYLIST_REGEX`` and either the function has neither a composite
    kernel nor an autogenerated composite kernel, or ``backend_index`` reports
    a kernel for it.
    """
    needs_lowering = not (f.has_composite_kernel
                          or has_autogenerated_composite_kernel(f))
    kernel_present = backend_index.has_kernel(f)
    denied = any(
        re.match(pattern, str(f.func.name)) for pattern in _FN_DENYLIST_REGEX)
    if denied:
        return False
    return needs_lowering or kernel_present
示例#3
0
def gen_unstructured(f: NativeFunction,
                     backend_index: BackendIndex) -> Optional[str]:
    """Return the declaration string for the kernel of ``f``, if any.

    Returns ``None`` when ``backend_index`` has no kernel for ``f`` or when
    the kernel name contains ``"legacy::"``. Otherwise returns the signature
    declaration named after the kernel, terminated by ``;`` and prefixed with
    ``TORCH_API `` unless the backend index is external.
    """
    signature = kernel_signature(f, backend_index)
    kernel_metadata = backend_index.get_kernel(f)
    if kernel_metadata is None or "legacy::" in kernel_metadata.kernel:
        return None
    api_prefix = 'TORCH_API ' if not backend_index.external else ''
    declaration = signature.decl(name=kernel_metadata.kernel)
    return f"{api_prefix}{declaration};"
示例#4
0
def gen_structured(g: NativeFunctionsGroup,
                   backend_index: BackendIndex) -> List[str]:
    """Return the C++ struct declaration for the structured kernel of ``g``.

    The result is an empty list when ``backend_index`` has no kernel for
    ``g``. Otherwise it is a one-element list holding a struct named
    ``structured_<kernel>`` that derives from ``at::meta::<meta name>`` and
    declares an ``impl`` method taking the structured impl arguments. The
    struct is marked ``TORCH_API`` unless the backend index is external.
    """
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    kernel_metadata = backend_index.get_kernel(g)
    if kernel_metadata is None:
        return []

    api_prefix = 'TORCH_API ' if not backend_index.external else ''
    impl_params = ', '.join(arg.decl() for arg in out_args)
    # Doubled braces are f-string escapes for the literal C++ braces.
    struct_decl = (
        f"struct {api_prefix}structured_{kernel_metadata.kernel}"
        f" : public at::meta::{meta_name} {{\n"
        f"void impl({impl_params});\n"
        f"}};\n"
    )
    return [struct_decl]
示例#5
0
 def create_backend_index(backend_ops: List[str],
                          dispatch_key: DispatchKey) -> BackendIndex:
     """Build an external ``BackendIndex`` for the listed operators.

     Each string in ``backend_ops`` is parsed into an ``OperatorName`` that
     must be present in ``native_functions_map``; its kernel name is the
     dispatcher name of the corresponding function. Every entry is recorded
     as unstructured.

     Raises:
         AssertionError: if a parsed operator name is not in
             ``native_functions_map``.
     """
     kernels: Dict[OperatorName, BackendMetadata] = {}
     for op_string in backend_ops:
         parsed_name = OperatorName.parse(op_string)
         assert parsed_name in native_functions_map, f"Found an invalid operator name: {parsed_name}"
         # See Note [External Backends Follow Dispatcher API]
         native_function = native_functions_map[parsed_name]
         # TODO: allow structured external backends later.
         kernels[parsed_name] = BackendMetadata(
             kernel=dispatcher.name(native_function.func), structured=False)
     # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
     # this should eventually be toggleable per-backend.
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=False,
         external=True,
         index=kernels)
示例#6
0
 def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey,
                          *, use_out_as_primary: bool,
                          use_device_guard: bool) -> BackendIndex:
     """Build an external ``BackendIndex`` for the listed operators.

     Each string in ``backend_ops`` is parsed into an ``OperatorName`` that
     must be present in ``native_functions_map``; its kernel name is the
     dispatcher name of the corresponding function. Every entry is recorded
     as unstructured. ``use_out_as_primary`` and ``use_device_guard`` are
     forwarded to the ``BackendIndex`` as ``use_out_as_primary`` and
     ``device_guard`` respectively.

     Raises:
         AssertionError: if a parsed operator name is not in
             ``native_functions_map``.
     """
     kernels: Dict[OperatorName, BackendMetadata] = {}
     for op_string in backend_ops:
         parsed_name = OperatorName.parse(op_string)
         assert parsed_name in native_functions_map, f"Found an invalid operator name: {parsed_name}"
         # See Note [External Backends Follow Dispatcher API]
         native_function = native_functions_map[parsed_name]
         # TODO: allow structured external backends later.
         kernels[parsed_name] = BackendMetadata(
             kernel=dispatcher.name(native_function.func), structured=False)
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=use_out_as_primary,
         external=True,
         device_guard=use_device_guard,
         index=kernels)
示例#7
0
def compute_native_function_declaration(
        g: Union[NativeFunctionsGroup,
                 NativeFunction], backend_index: BackendIndex) -> List[str]:
    """Return the kernel declarations for ``g`` under ``backend_index``.

    * A single ``NativeFunction`` yields zero or one unstructured declaration.
    * A ``NativeFunctionsGroup`` whose kernel metadata is missing or not
      structured yields the unstructured declarations of its functions,
      skipping those for which none is generated.
    * A structured group yields the output of ``gen_structured``.

    Raises:
        AssertionError: for a structured group on an external backend, which
            is not implemented yet.
    """
    kernel_metadata = backend_index.get_kernel(g)

    def _declare(fn: NativeFunction) -> Optional[str]:
        return gen_unstructured(fn, backend_index)

    if not isinstance(g, NativeFunctionsGroup):
        declaration = _declare(g)
        return [declaration] if declaration is not None else []

    if kernel_metadata is None or not kernel_metadata.structured:
        return list(mapMaybe(_declare, g.functions()))

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)