Ejemplo n.º 1
0
 def __init__(self, is_training: bool):
     """
     Set up the factory.

     :param is_training: True: components are created in training state;
                         False: components are created in non-training state
     """
     # The two attributes are independent of each other; the registry handle
     # is obtained first, then the training flag is stored.
     self._registry = Registry()
     self._is_training = is_training
Ejemplo n.º 2
0
    def register(cls,
                 name_space: str,
                 typename: str = None,
                 is_allowed_exist: bool = False) -> T:
        """
        Build a decorator that registers a class or function in a ``Registry``.

        :param name_space: name space under which the decorated class or
                           function is registered
        :param typename: type name to register under; when it is None or "",
                         the ``__name__`` of the decorated class or function
                         is used, so a configuration file can refer to the
                         class or function name directly
        :param is_allowed_exist: forwarded to ``Registry.register_class``.
                                 True: a duplicate name is allowed and the
                                 later registration overrides the earlier one
                                 (normally this should not be needed);
                                 False: duplicate names are not allowed and
                                 the user-defined name has to be changed
        :return: the decorator; it hands back the decorated object unchanged
        """
        registry = Registry()

        def add_to_registry(registered_class: Type[T]):
            # Fall back to the object's own name when no explicit (non-empty)
            # type name was supplied.
            resolved_name = typename if typename else registered_class.__name__

            registry.register_class(cls=registered_class,
                                    name_space=name_space,
                                    name=resolved_name,
                                    is_allowed_exist=is_allowed_exist)
            return registered_class

        return add_to_registry
Ejemplo n.º 3
0
def test_component_evaluate_factory():
    """
    Build the components described in ``training.json`` with a factory created
    with ``is_training=False`` and check that ``my_component.value`` is
    ``"evaluate_3"``.
    """
    Registry().clear_objects()

    config_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/training.json")
    with open(config_path, encoding="utf-8") as config_file:
        # OrderedDict keeps the key order of the JSON file.
        config = json.load(config_file, object_pairs_hook=OrderedDict)

    components = ComponentFactory(is_training=False).create(config=config)

    ASSERT.assertEqual("evaluate_3", components["my_component"].value)
Ejemplo n.º 4
0
def test_component_with_object():
    """
    Check component construction when some constructor arguments are
    themselves objects defined in the configuration.

    :return:
    """
    Registry().clear_objects()

    config_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/component_with_obj.json")
    with open(config_path, encoding="utf-8") as config_file:
        config = json.load(config_file, object_pairs_hook=OrderedDict)

    built = ComponentFactory(is_training=False).create(config=config)

    # A plain object defined at the root of the configuration.
    my_obj = built["my_obj"]
    ASSERT.assertEqual(10, my_obj.value)

    # A component whose sub model and customer object come from the config.
    my_component: _ModelWithObjParam = built["my_component"]
    sub_model = my_component.sub_model
    ASSERT.assertEqual(4, sub_model.in_features)
    ASSERT.assertEqual(2, sub_model.out_features)

    # The component must hold the very same instance as the root entry.
    ASSERT.assertTrue(my_obj is my_component.customer_obj)

    another_component: _ModelWithObjParam = built["another_component"]
    ASSERT.assertTrue(my_component is not another_component)

    another_obj: _CustomerObj = built["another_obj"]
    ASSERT.assertTrue(another_obj is another_component.customer_obj)
    ASSERT.assertEqual(20, another_obj.value)

    # A component that receives both an object and a raw dictionary.
    dict_param_component: _DictParamComponent = built["dict_param_component"]
    ASSERT.assertTrue(dict_param_component.curstomer_obj is another_obj)

    dict_value = dict_param_component.dict_value
    ASSERT.assertEqual(1, dict_value["a"])
    ASSERT.assertEqual(2, dict_value["b"])
    ASSERT.assertEqual(30, dict_value["c_obj"].value)

    # A root entry with a basic value is returned as is.
    ASSERT.assertEqual("my_test_value", built["my_object"])
Ejemplo n.º 5
0
def test_component_factory():
    """
    Build the model described in ``model.json`` with a factory created with
    ``is_training=True`` and check the feature sizes of its ``linear`` member.
    """
    Registry().clear_objects()

    model_config_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/model.json")
    with open(model_config_path, encoding="utf-8") as config_file:
        config = json.load(config_file, object_pairs_hook=OrderedDict)

    built = ComponentFactory(is_training=True).create(config=config)
    model = built["model"]

    ASSERT.assertTrue(model.linear is not None)

    linear = model.linear
    ASSERT.assertEqual((2, 4), (linear.in_features, linear.out_features))
Ejemplo n.º 6
0
def test_default_typename():
    """
    Build the components described in ``default_typename.json`` and check that
    the ``default_typename`` entry has ``value`` equal to 10.

    :return:
    """
    Registry().clear_objects()

    config_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/default_typename.json")
    with open(config_path, encoding="utf-8") as config_file:
        config = json.load(config_file, object_pairs_hook=OrderedDict)

    built = ComponentFactory(is_training=False).create(config=config)

    ASSERT.assertEqual(10, built["default_typename"].value)
Ejemplo n.º 7
0
class ComponentFactory:
    """
    Component Factory.

    Turns a nested ``OrderedDict`` configuration into objects. Every
    ``OrderedDict`` node is handled according to the builtin keys it contains:

    * ``ComponentBuiltinKey.OBJECT``: the node refers to an object looked up
      in the registry;
    * ``ComponentBuiltinKey.TYPENAME`` together with
      ``ComponentBuiltinKey.NAME_SPACE``: the node is instantiated from the
      class or function registered under that name;
    * neither of them: the node is a plain dictionary argument whose nested
      ``OrderedDict`` values are resolved recursively.
    """

    def __init__(self, is_training: bool):
        """
        Set up the factory.

        :param is_training: True: components are created in training state;
                            False: components are created in non-training state
        """
        self._is_training = is_training
        self._registry = Registry()

    def _resolve_nested(self, path: str, param_dict: OrderedDict) -> None:
        """
        Replace, in place, every ``OrderedDict`` value of ``param_dict`` with
        the result of building it.

        :param path: dotted path of ``param_dict``; a nested value is built
                     under ``"<path>.<param_name>"``
        :param param_dict: dictionary whose values are resolved
        """
        # Only existing keys are reassigned, so the dictionary never changes
        # size while it is being iterated.
        for param_name, param_value in param_dict.items():
            if isinstance(param_value, OrderedDict):
                param_dict[param_name] = self._create(
                    path=f"{path}.{param_name}", param_dict=param_value)

    @staticmethod
    def _needs_is_training(cls) -> bool:
        """
        Tell whether ``cls`` should receive the ``is_training`` argument.

        :param cls: the registered class or function
        :return: True for a subclass of ``Component`` and for a function that
                 declares an ``is_training`` parameter, False otherwise
        :raises RuntimeError: if ``cls`` is neither a class nor a function
        """
        if inspect.isclass(cls):
            # Classes that are not a Component get no extra argument.
            return issubclass(cls, Component)

        if inspect.isfunction(cls):
            spec = inspect.getfullargspec(cls)
            # ``args`` does not contain keyword-only parameters, so
            # ``kwonlyargs`` has to be checked as well; otherwise a function
            # declared as ``def f(*, is_training)`` would be missed.
            return (ComponentBuiltinKey.IS_TRAINING in spec.args
                    or ComponentBuiltinKey.IS_TRAINING in spec.kwonlyargs)

        raise RuntimeError(f"{cls} 错误! 应该是 函数 或者是 类的静态函数, 不能是类函数或者成员函数")

    def _create_by_object(self, path: str, param_dict: OrderedDict):
        """
        Return an object that the node refers to through the ``OBJECT`` key.

        :param path: dotted path of the node (not used by this method)
        :param param_dict: node containing ``ComponentBuiltinKey.OBJECT``
        :return: whatever ``Registry.find_object`` returns for that value
        """
        object_path = param_dict[ComponentBuiltinKey.OBJECT]
        return self._registry.find_object(object_path)

    def _create_by_type(self, path: str, param_dict: OrderedDict):
        """
        Create an object from its type name and name space.

        The ``TYPENAME`` and ``NAME_SPACE`` entries are removed from
        ``param_dict``; the remaining entries become keyword arguments of the
        registered class or function. The created object is registered in the
        registry under ``path``.

        :param path: dotted path of the node
        :param param_dict: node containing ``ComponentBuiltinKey.TYPENAME``
                           and ``ComponentBuiltinKey.NAME_SPACE``
        :return: the created object
        :raises RuntimeError: if nothing is registered under the name, or the
                              registered callable is neither a class nor a
                              function
        :raises TypeError: re-raised, after logging, when the call fails
        """
        component_type = param_dict.pop(ComponentBuiltinKey.TYPENAME)
        name_space = param_dict.pop(ComponentBuiltinKey.NAME_SPACE)

        cls = self._registry.find_class(name=component_type,
                                        name_space=name_space)

        if cls is None:
            raise RuntimeError(f"{name_space}:{component_type} 没有被注册")

        # Nested nodes are built before the callable itself is inspected.
        self._resolve_nested(path=path, param_dict=param_dict)

        # Add the is_training argument unless the configuration already
        # provides one explicitly.
        if self._needs_is_training(cls) and (ComponentBuiltinKey.IS_TRAINING
                                             not in param_dict):
            param_dict[ComponentBuiltinKey.IS_TRAINING] = self._is_training

        try:
            obj = cls(**param_dict)
        except TypeError as type_error:
            logging.fatal(f"Exception: {type_error} for {cls}")
            logging.fatal(traceback.format_exc())
            raise

        self._registry.register_object(name=path, obj=obj)
        return obj

    def _create_by_raw_dict(self, path: str, param_dict: OrderedDict):
        """
        Handle a node without type name and name space: the dictionary itself
        is the argument.

        :param path: dotted path of the node
        :param param_dict: the dictionary; it is modified in place
        :return: ``param_dict`` with its nested ``OrderedDict`` values built
        """
        self._resolve_nested(path=path, param_dict=param_dict)
        return param_dict

    def _create(self, path: str, param_dict: OrderedDict):
        """
        Dispatch a node to the matching creation method.

        :param path: dotted path of the node
        :param param_dict: the node
        :return: the object built for the node
        :raises RuntimeError: if only one of ``TYPENAME`` and ``NAME_SPACE``
                              is present
        """
        # The OBJECT key takes precedence over every other key.
        if ComponentBuiltinKey.OBJECT in param_dict:
            return self._create_by_object(path=path, param_dict=param_dict)

        has_typename = ComponentBuiltinKey.TYPENAME in param_dict
        has_name_space = ComponentBuiltinKey.NAME_SPACE in param_dict

        if has_typename and has_name_space:
            return self._create_by_type(path=path, param_dict=param_dict)

        if has_typename or has_name_space:
            # Exactly one of the two keys is present.
            raise RuntimeError(
                f"构建 {path} 错误, "
                f"{ComponentBuiltinKey.TYPENAME} 与 {ComponentBuiltinKey.NAME_SPACE} 必须同时出现"
            )

        # In this case the argument is simply a dictionary.
        return self._create_by_raw_dict(path=path, param_dict=param_dict)

    def create(self, config: OrderedDict):
        """
        Create the objects described by ``config``.

        :param config: config dictionary; it must be an ``OrderedDict`` and
                       its keys are processed in order. It is deep-copied
                       first, so the caller's dictionary is not modified.
        :return: the copy, in which every root ``OrderedDict`` value has been
                 replaced by the object built from it; other root values are
                 left as they are
        :raises AssertionError: if ``config`` is not an ``OrderedDict``
        """
        # Kept as an assert so that callers relying on AssertionError keep
        # working; note that asserts are skipped when Python runs with -O.
        assert isinstance(
            config,
            OrderedDict), f"param_dict type: {type(config)} 不是 OrderedDict"

        parsed_config = copy.deepcopy(config)

        for obj_name, param_dict in parsed_config.items():
            # Root values of a basic type need no processing.
            if isinstance(param_dict, OrderedDict):
                parsed_config[obj_name] = self._create(obj_name,
                                                       param_dict=param_dict)
        return parsed_config