def setUp(self) -> None:
    """Build two autogen-enabled native functions and a merged backend index.

    Populates:
      * ``self.native_functions``: an empty list.
      * ``self.one_return_func``: ``op(Tensor self) -> Tensor``, parsed from a
        YAML snippet with ``LineLoader``. It declares a
        ``CompositeExplicitAutograd`` kernel named ``op`` and ``autogen: op.out``.
      * ``self.two_returns_func``: ``op_2() -> (Tensor, Tensor)``, built from a
        plain dict. It declares a ``CPU`` kernel ``kernel_1`` and
        ``autogen: op_2.out``.
      * ``self.backend_indices``: the backend metadata returned for both
        functions, merged with ``BackendIndex.grow_index``.
    """
    self.native_functions: List[NativeFunction] = []
    # defaultdict(dict): any dispatch key touched on first access gets an empty
    # per-operator dict. Presumably grow_index indexes the parent by key, which
    # is why a plain {} would not do here; confirm against torchgen.
    self.backend_indices: Dict[
        DispatchKey, Dict[OperatorName, BackendMetadata]
    ] = defaultdict(dict)

    # First function: exercised through the YAML loader path.
    yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
"""
    one_return_entry = yaml.load(yaml_entry, Loader=LineLoader)[0]
    self.one_return_func, one_return_backend_index = NativeFunction.from_yaml(
        one_return_entry, loc=Location(__file__, 1), valid_tags=set()
    )

    # Second function: given directly as an already-parsed mapping.
    two_returns_entry = {
        "func": "op_2() -> (Tensor, Tensor)",
        "dispatch": {"CPU": "kernel_1"},
        "autogen": "op_2.out",
    }
    self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml(
        two_returns_entry,
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )

    for child_index in (one_return_backend_index, two_returns_backend_index):
        BackendIndex.grow_index(self.backend_indices, child_index)
def setUp(self) -> None:
    """Build two CPU-dispatched native functions and per-key BackendIndex objects.

    ``op_1`` has only a ``CPU`` kernel (``kernel_1``). ``op_2`` has a ``CPU``
    kernel (``kernel_2``) plus a ``QuantizedCPU`` kernel given with a custom
    namespace (``custom::kernel_3``). Their backend metadata is merged per
    dispatch key.

    ``self.backend_indices`` maps each of ``DispatchKey.CPU`` and
    ``DispatchKey.QuantizedCPU`` to a ``BackendIndex``. Each is built with
    ``use_out_as_primary=True``, ``external=False``, ``symint=False`` and
    ``device_guard=False``.
    """
    op_1_entry = {
        "func": "op_1() -> bool",
        "dispatch": {"CPU": "kernel_1"},
    }
    op_2_entry = {
        "func": "op_2() -> bool",
        "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
    }

    self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
        op_1_entry,
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )
    self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
        op_2_entry,
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )

    # Only the two keys used above are pre-seeded. Presumably grow_index
    # expects every child key to already exist in the parent; confirm.
    per_key_ops: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
        DispatchKey.CPU: {},
        DispatchKey.QuantizedCPU: {},
    }
    for child_index in (op_1_backend_index, op_2_backend_index):
        BackendIndex.grow_index(per_key_ops, child_index)

    self.backend_indices = {
        dispatch_key: BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=True,
            external=False,
            symint=False,
            device_guard=False,
            index=ops,
        )
        for dispatch_key, ops in per_key_ops.items()
    }
def test_custom_namespace_selected_correctly(self):
    """A ``custom::`` operator listed in the selective-build YAML is selected.

    The config lists ``aten::add.int`` and ``custom::add``. The test checks
    that a native function parsed as ``custom::add() -> Tensor`` is reported
    as selected by the resulting ``SelectiveBuilder``.
    """
    yaml_config = """
operators:
  aten::add.int:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
  custom::add:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: Yes
"""
    # from_yaml returns (native_function, backend_metadata); only the
    # function itself is needed here.
    custom_add_fn = NativeFunction.from_yaml(
        {"func": "custom::add() -> Tensor"},
        loc=Location(__file__, 1),
        valid_tags=set(),
    )[0]

    selector = SelectiveBuilder.from_yaml_str(yaml_config)
    is_selected = selector.is_native_function_selected(custom_add_fn)
    self.assertTrue(is_selected)