def __init__(self, is_training: bool):
    """
    Set up the factory state.

    :param is_training: True -> components are created in training mode;
                        False -> components are created in non-training mode.
    """
    # Remember the mode first; it is later injected as the is_training argument.
    self._is_training = is_training
    # Registry used to look up registered classes and previously created objects.
    self._registry = Registry()
def register(cls, name_space: str, typename: str = None, is_allowed_exist: bool = False) -> T:
    """
    Build a decorator that registers the decorated class or function in the Registry.

    :param name_space: name space the decorated class/function is registered under.
    :param typename: name to register under. If None or "", the decorated object's
                     own ``__name__`` is used, so a config file can refer to it
                     directly by its class or function name.
    :param is_allowed_exist: True allows a duplicate name (the later registration
                             overwrites the earlier one; normally this should not be
                             needed). False forbids duplicates, in which case your own
                             name must be changed. The flag is forwarded unchanged to
                             ``Registry.register_class``, which presumably enforces it.
    :return: a decorator that registers its argument and returns it unchanged.
    """
    registry = Registry()

    def _decorate(target: Type[T]):
        # ``or`` makes both None and "" fall back to the object's own name.
        registry.register_class(cls=target,
                                name_space=name_space,
                                name=typename or target.__name__,
                                is_allowed_exist=is_allowed_exist)
        return target

    return _decorate
def test_component_evaluate_factory():
    """
    Build ``training.json`` with is_training=False and check that
    ``my_component.value`` is "evaluate_3".
    """
    Registry().clear_objects()

    path = os.path.join(ROOT_PATH, "data/easytext/tests/component/training.json")
    with open(path, encoding="utf-8") as json_file:
        config = json.load(json_file, object_pairs_hook=OrderedDict)

    parsed = ComponentFactory(is_training=False).create(config=config)

    component = parsed["my_component"]
    ASSERT.assertEqual("evaluate_3", component.value)
def test_component_with_object():
    """
    Test building components when one of a component's parameters is an
    already-built object: the component must receive that same instance.
    """
    Registry().clear_objects()

    path = os.path.join(ROOT_PATH, "data/easytext/tests/component/component_with_obj.json")
    with open(path, encoding="utf-8") as json_file:
        config = json.load(json_file, object_pairs_hook=OrderedDict)

    parsed = ComponentFactory(is_training=False).create(config=config)

    # First object and the component that takes it as a parameter.
    first_obj = parsed["my_obj"]
    ASSERT.assertEqual(10, first_obj.value)

    first_component: _ModelWithObjParam = parsed["my_component"]
    ASSERT.assertEqual(4, first_component.sub_model.in_features)
    ASSERT.assertEqual(2, first_component.sub_model.out_features)
    # Identity, not equality: the component must hold the very same instance.
    ASSERT.assertTrue(first_obj is first_component.customer_obj)

    # A second, distinct component wired to a second object.
    second_component: _ModelWithObjParam = parsed["another_component"]
    ASSERT.assertTrue(first_component is not second_component)

    second_obj: _CustomerObj = parsed["another_obj"]
    ASSERT.assertTrue(second_obj is second_component.customer_obj)
    ASSERT.assertEqual(20, second_obj.value)

    # A component whose parameters include an object and a plain dict
    # (which itself contains a nested object).
    dict_component: _DictParamComponent = parsed["dict_param_component"]
    ASSERT.assertTrue(dict_component.curstomer_obj is second_obj)
    ASSERT.assertEqual(1, dict_component.dict_value["a"])
    ASSERT.assertEqual(2, dict_component.dict_value["b"])
    ASSERT.assertEqual(30, dict_component.dict_value["c_obj"].value)

    # A plain (non-dict) top-level value is passed through untouched.
    ASSERT.assertEqual("my_test_value", parsed["my_object"])
def test_component_factory():
    """
    Build ``model.json`` with is_training=True and check that the model's
    ``linear`` layer exists with in/out features (2, 4).
    """
    Registry().clear_objects()

    path = os.path.join(ROOT_PATH, "data/easytext/tests/component/model.json")
    with open(path, encoding="utf-8") as json_file:
        config = json.load(json_file, object_pairs_hook=OrderedDict)

    parsed = ComponentFactory(is_training=True).create(config=config)

    model = parsed["model"]
    ASSERT.assertTrue(model.linear is not None)
    ASSERT.assertEqual((2, 4), (model.linear.in_features, model.linear.out_features))
def test_default_typename():
    """
    Test building the ``default_typename`` entry of ``default_typename.json``.

    Judging by the test and file names, this exercises a component registered
    without an explicit typename, i.e. under its own class/function name
    (the ``typename or __name__`` fallback in ``register``). TODO: confirm
    against the JSON fixture.
    """
    Registry().clear_objects()

    config_json_file_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/default_typename.json")
    with open(config_json_file_path, encoding="utf-8") as f:
        param_dict = json.load(f, object_pairs_hook=OrderedDict)

    factory = ComponentFactory(is_training=False)
    parsed_dict = factory.create(config=param_dict)

    default_typename = parsed_dict["default_typename"]
    ASSERT.assertEqual(10, default_typename.value)
class ComponentFactory:
    """
    Component factory: builds objects from an ordered configuration dict.

    Every ``OrderedDict`` value in the config is resolved by ``_create``:

    * if it contains ``ComponentBuiltinKey.OBJECT``, an already-created object
      is fetched from the registry;
    * if it contains both ``ComponentBuiltinKey.TYPENAME`` and
      ``ComponentBuiltinKey.NAME_SPACE``, the registered class/function is
      called with the remaining entries as keyword arguments and the result is
      registered under its config path;
    * otherwise it is a plain parameter dict whose nested dicts are resolved
      recursively.
    """

    def __init__(self, is_training: bool):
        """
        Initialize the factory.

        :param is_training: True: create components in training mode;
                            False: create components in non-training mode.
        """
        self._is_training = is_training
        self._registry = Registry()

    def _resolve_nested(self, path: str, param_dict: OrderedDict) -> None:
        """
        Replace, in place, every ``OrderedDict`` value of ``param_dict`` with the
        object built from it; the child path is ``"{path}.{param_name}"``.
        Non-dict values are left untouched.
        """
        for param_name, param_value in param_dict.items():
            if isinstance(param_value, OrderedDict):
                # Re-assigning an existing key does not change the dict's size
                # or order, so it is safe while iterating.
                param_dict[param_name] = self._create(path=f"{path}.{param_name}",
                                                      param_dict=param_value)

    @staticmethod
    def _needs_is_training(cls) -> bool:
        """
        Decide whether ``is_training`` should be injected when calling ``cls``.

        :param cls: the registered class or function.
        :return: True for ``Component`` subclasses, and for functions declaring an
                 ``is_training`` parameter (positional-or-keyword or keyword-only).
                 False for any other class.
        :raises RuntimeError: if ``cls`` is neither a class nor a plain function
                              (e.g. a bound method).
        """
        if inspect.isclass(cls):
            # Non-Component classes are constructed as-is.
            return issubclass(cls, Component)
        if inspect.isfunction(cls):
            spec = inspect.getfullargspec(cls)
            # kwonlyargs covers ``def f(*, is_training)``, which ``args`` misses.
            return ComponentBuiltinKey.IS_TRAINING in (spec.args + spec.kwonlyargs)
        raise RuntimeError(f"{cls} 错误! 应该是 函数 或者是 类的静态函数, 不能是类函数或者成员函数")

    def _create_by_object(self, path: str, param_dict: OrderedDict):
        """
        Resolve a reference to an object that was created earlier; it is simply
        fetched from the registry.

        :param path: config path of the entry being resolved (not used here).
        :param param_dict: dict holding ``ComponentBuiltinKey.OBJECT``, whose value
                           is the registry path of the existing object.
        :return: whatever ``Registry.find_object`` returns for that path.
        """
        object_path = param_dict[ComponentBuiltinKey.OBJECT]
        return self._registry.find_object(object_path)

    def _create_by_type(self, path: str, param_dict: OrderedDict):
        """
        Build an object from a registered class/function given by type name and
        name space.

        :param path: config path of this entry; the created object is registered
                     under it.
        :param param_dict: parameters; the typename and name-space keys are popped,
                           nested dicts are resolved, and the rest become kwargs.
        :return: the created object.
        :raises RuntimeError: if the type is not registered, or ``cls`` is not a
                              class or plain function.
        :raises TypeError: re-raised (after logging) if the call itself fails.
        """
        component_type = param_dict.pop(ComponentBuiltinKey.TYPENAME)
        name_space = param_dict.pop(ComponentBuiltinKey.NAME_SPACE)

        cls = self._registry.find_class(name=component_type, name_space=name_space)
        if cls is None:
            raise RuntimeError(f"{name_space}:{component_type} 没有被注册")

        self._resolve_nested(path=path, param_dict=param_dict)

        # An explicit is_training in the config always wins over the factory mode.
        if self._needs_is_training(cls) and ComponentBuiltinKey.IS_TRAINING not in param_dict:
            param_dict[ComponentBuiltinKey.IS_TRAINING] = self._is_training

        try:
            obj = cls(**param_dict)
        except TypeError as type_error:
            logging.fatal(f"Exception: {type_error} for {cls}")
            logging.fatal(traceback.format_exc())
            raise

        self._registry.register_object(name=path, obj=obj)
        return obj

    def _create_by_raw_dict(self, path: str, param_dict: OrderedDict):
        """
        No typename / name space: the dict itself is the value. Nested dicts are
        resolved in place.

        :param path: config path of this entry.
        :param param_dict: the plain parameter dict.
        :return: the same dict, with nested dicts replaced by built objects.
        """
        self._resolve_nested(path=path, param_dict=param_dict)
        return param_dict

    def _create(self, path: str, param_dict: OrderedDict):
        """
        Dispatch on the builtin keys present in ``param_dict``.

        :raises RuntimeError: if exactly one of typename / name space is present.
        """
        if ComponentBuiltinKey.OBJECT in param_dict:
            return self._create_by_object(path=path, param_dict=param_dict)

        has_typename = ComponentBuiltinKey.TYPENAME in param_dict
        has_name_space = ComponentBuiltinKey.NAME_SPACE in param_dict

        if has_typename and has_name_space:
            return self._create_by_type(path=path, param_dict=param_dict)

        if has_typename != has_name_space:
            raise RuntimeError(
                f"构建 {path} 错误, "
                f"{ComponentBuiltinKey.TYPENAME} 与 {ComponentBuiltinKey.NAME_SPACE} 必须同时出现"
            )

        # Neither key present: the value is just a parameter dict.
        return self._create_by_raw_dict(path=path, param_dict=param_dict)

    def create(self, config: OrderedDict):
        """
        Build all objects described by ``config``.

        :param config: config dict; must be an ``OrderedDict`` because its keys are
                       processed in order (later entries may reference earlier ones).
        :return: a deep copy of ``config`` in which every top-level ``OrderedDict``
                 value is replaced by the object built from it; other top-level
                 values are kept as-is.
        """
        assert isinstance(
            config, OrderedDict), f"param_dict type: {type(config)} 不是 OrderedDict"

        # Work on a copy so the caller's config is not mutated (keys are popped).
        parsed_config = copy.deepcopy(config)
        for obj_name, param_dict in parsed_config.items():
            if isinstance(param_dict, OrderedDict):
                parsed_config[obj_name] = self._create(obj_name, param_dict=param_dict)
            # Non-dict top-level values are basic types and need no processing.
        return parsed_config