Пример #1
0
 def create_backend_index(
     backend_ops: List[str],
     dispatch_key: DispatchKey,
     *,
     use_out_as_primary: bool,
     use_device_guard: bool,
     symint: bool = True,
 ) -> BackendIndex:
     """Build an external ``BackendIndex`` for ``dispatch_key`` from op names.

     Each string in ``backend_ops`` is parsed with ``OperatorName.parse`` and
     looked up in the enclosing scope's ``native_functions_map``. The kernel
     name recorded for it is the dispatcher name of that native function's
     schema (see Note [External Backends Follow Dispatcher API]). Every entry
     is registered as unstructured and tagged with the enclosing scope's
     ``cpp_namespace``. If an op appears more than once in ``backend_ops``,
     the later entry overwrites the earlier one.

     Args:
         backend_ops: Operator-name strings implemented by the backend.
         dispatch_key: Dispatch key stored on the resulting index.
         use_out_as_primary: Forwarded to ``BackendIndex(use_out_as_primary=...)``.
         use_device_guard: Forwarded to ``BackendIndex(device_guard=...)``.
         symint: Forwarded to ``BackendIndex(symint=...)``. Defaults to
             ``True``, the value this function previously hard-coded.

     Returns:
         A ``BackendIndex`` with ``external=True`` holding one
         ``BackendMetadata`` per distinct operator in ``backend_ops``.

     Raises:
         AssertionError: if an entry of ``backend_ops`` does not name an
             operator present in ``native_functions_map``.
     """
     metadata: Dict[OperatorName, BackendMetadata] = {}
     for op in backend_ops:
         op_name = OperatorName.parse(op)
         # Explicit check rather than a bare `assert`, so the validation (and
         # its helpful message) survives `python -O`. AssertionError is kept
         # so the exception type seen by callers does not change.
         if op_name not in native_functions_map:
             raise AssertionError(f"Found an invalid operator name: {op_name}")
         # See Note [External Backends Follow Dispatcher API]
         kernel_name = dispatcher.name(native_functions_map[op_name].func)
         # TODO: allow structured external backends later.
         metadata[op_name] = BackendMetadata(
             kernel=kernel_name,
             structured=False,
             cpp_namespace=cpp_namespace,
         )
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=use_out_as_primary,
         external=True,
         symint=symint,
         device_guard=use_device_guard,
         index=metadata,
     )
Пример #2
0
    def setUp(self) -> None:
        """Create two toy native functions and the backend indices built from them.

        ``op_1`` declares only a CPU kernel (``kernel_1``). ``op_2`` declares a
        CPU kernel (``kernel_2``) and a QuantizedCPU kernel with the namespaced
        name ``custom::kernel_3``. The per-op backend tables returned by
        ``NativeFunction.from_yaml`` are fed through ``BackendIndex.grow_index``
        into one table per dispatch key. Each table is then wrapped in a
        ``BackendIndex`` and stored on ``self.backend_indices``.
        """

        def parse_op(func, dispatch):
            # All fake ops share a dummy source location and an empty tag set;
            # only the schema string and the dispatch table differ.
            return NativeFunction.from_yaml(
                {"func": func, "dispatch": dispatch},
                loc=torchgen.model.Location(__file__, 1),
                valid_tags=set(),
            )

        self.op_1_native_function, op_1_index = parse_op(
            "op_1() -> bool", {"CPU": "kernel_1"}
        )
        self.op_2_native_function, op_2_index = parse_op(
            "op_2() -> bool",
            {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
        )

        # One initially empty per-operator table for each dispatch key under
        # test; op_1's entries are folded in before op_2's.
        merged: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        for partial in (op_1_index, op_2_index):
            BackendIndex.grow_index(merged, partial)

        self.backend_indices = {
            key: BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                symint=False,
                device_guard=False,
                index=per_op,
            )
            for key, per_op in merged.items()
        }