def get_selector_from_legacy_operator_selection_list(
    selected_op_list_path: str,
) -> Any:
    """Build a selective-build selector from a legacy YAML operator list.

    The file is parsed with ``yaml.load(..., Loader=YamlLoader)``. The loaded
    document is iterated, presumably as a list of operator-name strings; each
    name is reduced to its part before the first ``"."``, which drops the
    overload suffix (e.g. ``"aten::add.int"`` -> ``"aten::add"``). The resulting
    set of base names is handed to
    ``SelectiveBuilder.from_legacy_op_registration_allow_list`` with every
    operator marked as a root operator and as used for training.

    This handling exists only for the legacy config format - do NOT copy it.

    Args:
        selected_op_list_path: Path to the legacy YAML operator list.

    Returns:
        The selector returned by
        ``SelectiveBuilder.from_legacy_op_registration_allow_list``.
    """
    with open(selected_op_list_path) as op_list_file:
        loaded_names = yaml.load(op_list_file, Loader=YamlLoader)
        base_op_names = set()
        for qualified_name in loaded_names:
            # Keep only what precedes the first "." (strip the overload part).
            base_name, _sep, _overload = qualified_name.partition(".")
            base_op_names.add(base_name)

    # Function-local import. NOTE(review): presumably kept local to avoid an
    # import cycle - confirm before hoisting it to module level.
    from torchgen.selective_build.selector import SelectiveBuilder

    # The internal build no longer uses these flags; only the OSS build does.
    # There, every operator is treated as a root operator (so unboxing code is
    # generated for it, consistent with the current behaviour) and as used for
    # training, since OSS doesn't support training on mobile for now.
    return SelectiveBuilder.from_legacy_op_registration_allow_list(
        base_op_names,
        True,  # is_root_operator
        True,  # is_used_for_training
    )
def test_selector_factory(self):
    """Check the SelectiveBuilder constructors.

    Covers YAML configs (v1, v2, include-all, and an invalid config) and the
    legacy op-registration allow list under each combination of the
    root-operator / used-for-training flags exercised here.
    """
    yaml_config_v1 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::add:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: Yes
  aten::add.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
  aten::mul.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

    yaml_config_v2 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::sub:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
    debug_info:
      - model1@v100
  aten::sub.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

    yaml_config_all = "include_all_operators: Yes"

    yaml_config_invalid = "invalid:"

    # v1: overload name is not used for checking, so an overload that is not
    # listed at all ("aten::add.float") is still reported as selected.
    v1_selector = SelectiveBuilder.from_yaml_str(yaml_config_v1)
    for op_name in ("aten::add", "aten::add.int", "aten::add.float"):
        self.assertTrue(v1_selector.is_operator_selected(op_name))

    # A config consisting only of an unknown "invalid:" key is expected to raise.
    with self.assertRaises(Exception):
        SelectiveBuilder.from_yaml_str(yaml_config_invalid)

    # include_all_operators selects every operator and every kernel dtype.
    all_selector = SelectiveBuilder.from_yaml_str(yaml_config_all)
    for op_name in ("aten::add", "aten::sub", "aten::sub.int"):
        self.assertTrue(all_selector.is_operator_selected(op_name))
    self.assertTrue(all_selector.is_kernel_dtype_selected("add_kernel", "int32"))

    # v2: only the listed operators are selected.
    v2_selector = SelectiveBuilder.from_yaml_str(yaml_config_v2)
    self.assertFalse(v2_selector.is_operator_selected("aten::add"))
    for op_name in ("aten::sub", "aten::sub.int"):
        self.assertTrue(v2_selector.is_operator_selected(op_name))

    def make_legacy_selector(is_root_operator, is_used_for_training):
        # A fresh allow-list is built on every call.
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            is_root_operator,
            is_used_for_training,
        )

    # Legacy allow list, neither root nor training.
    plain_selector = make_legacy_selector(False, False)
    for op_name in ("aten::add.float", "aten::add", "aten::add.int"):
        self.assertTrue(plain_selector.is_operator_selected(op_name))
    self.assertFalse(plain_selector.is_operator_selected("aten::sub"))
    self.assertFalse(plain_selector.is_root_operator("aten::add"))
    self.assertFalse(plain_selector.is_operator_selected_for_training("aten::add"))

    # Legacy allow list, root only: the flag also covers the unlisted overload.
    root_selector = make_legacy_selector(True, False)
    for op_name in ("aten::add", "aten::add.float"):
        self.assertTrue(root_selector.is_root_operator(op_name))
        self.assertFalse(root_selector.is_operator_selected_for_training(op_name))

    # Legacy allow list, training only.
    training_selector = make_legacy_selector(False, True)
    for op_name in ("aten::add", "aten::add.float"):
        self.assertFalse(training_selector.is_root_operator(op_name))
        self.assertTrue(training_selector.is_operator_selected_for_training(op_name))