Esempio n. 1
0
    def register(cls,
                 name_space: str,
                 typename: "Optional[str]" = None,
                 is_allowed_exist: bool = False) -> "Callable[[Type[T]], Type[T]]":
        """
        Build a decorator that registers a class or function in the Registry.

        Apply the returned decorator to a class or function definition.

        :param name_space: the name space the class or function is registered under.
        :param typename: the type name to register under. If None or "", the decorated
                         class's/function's ``__name__`` is used instead, so the config
                         file can refer to it directly by that class or function name.
        :param is_allowed_exist: True: allow duplicate names, in which case a later
                                 registration overwrites an earlier one (normally this
                                 setting should not be needed);
                                 False: duplicate names are not allowed; if a clash
                                 occurs, the user-defined name has to be changed.
                                 The flag is passed straight to ``register_class``.
        :return: a decorator that registers its argument and returns it unchanged.
        """
        # NOTE(review): the string annotations in the signature refer to
        # typing.Optional / typing.Callable; static type checkers need those
        # names imported in this module — confirm the import block has them.
        registry = Registry()

        def add_to_registry(registered_class: Type[T]) -> Type[T]:
            # ``or`` falls back to __name__ for both None and "", as documented.
            name = typename or registered_class.__name__

            registry.register_class(cls=registered_class,
                                    name_space=name_space,
                                    name=name,
                                    is_allowed_exist=is_allowed_exist)
            # Return the original object so decorating it does not replace it.
            return registered_class

        return add_to_registry
Esempio n. 2
0
 def __init__(self, is_training: bool):
     """
     Initialize the instance.

     :param is_training: True to create components in the training state;
                         False to create components in the non-training state.
     """
     # NOTE(review): Registry() appears to be shared state (the tests reset it
     # with Registry().clear_objects()) — confirm.
     self._registry = Registry()
     self._is_training = is_training
Esempio n. 3
0
def test_component_evaluate_factory():
    """With is_training=False, "my_component" from training.json has value "evaluate_3"."""
    Registry().clear_objects()

    json_path = os.path.join(ROOT_PATH,
                             "data/easytext/tests/component/training.json")
    with open(json_path, encoding="utf-8") as config_file:
        config = json.load(config_file, object_pairs_hook=OrderedDict)

    components = ComponentFactory(is_training=False).create(config=config)

    ASSERT.assertEqual("evaluate_3", components["my_component"].value)
Esempio n. 4
0
def test_component_with_object():
    """
    Test building components when one of a component's parameters is an object.

    Objects declared in the config must be the very same instances that the
    components referencing them receive, and separate component entries must
    produce separate instances.
    """
    Registry().clear_objects()
    config_json_file_path = "data/easytext/tests/component/component_with_obj.json"
    config_json_file_path = os.path.join(ROOT_PATH, config_json_file_path)
    with open(config_json_file_path, encoding="utf-8") as f:
        param_dict = json.load(f, object_pairs_hook=OrderedDict)

    factory = ComponentFactory(is_training=False)

    parsed_dict = factory.create(config=param_dict)

    my_obj = parsed_dict["my_obj"]

    ASSERT.assertEqual(10, my_obj.value)

    my_component: _ModelWithObjParam = parsed_dict["my_component"]

    ASSERT.assertEqual(4, my_component.sub_model.in_features)
    ASSERT.assertEqual(2, my_component.sub_model.out_features)

    # The declared object must be passed by reference, not copied.
    ASSERT.assertTrue(my_obj is my_component.customer_obj)

    another_component: _ModelWithObjParam = parsed_dict["another_component"]

    # Two component entries must yield two distinct instances.
    ASSERT.assertTrue(my_component is not another_component)

    another_obj: _CustomerObj = parsed_dict["another_obj"]
    ASSERT.assertTrue(another_obj is another_component.customer_obj)

    ASSERT.assertEqual(20, another_obj.value)

    dict_param_component: _DictParamComponent = parsed_dict[
        "dict_param_component"]
    # ``curstomer_obj`` (sic): the spelling must match the attribute name on
    # _DictParamComponent, so do not "fix" it here alone.
    ASSERT.assertTrue(dict_param_component.curstomer_obj is another_obj)

    ASSERT.assertEqual(1, dict_param_component.dict_value["a"])
    ASSERT.assertEqual(2, dict_param_component.dict_value["b"])
    ASSERT.assertEqual(30, dict_param_component.dict_value["c_obj"].value)

    my_object = parsed_dict["my_object"]
    ASSERT.assertEqual("my_test_value", my_object)
Esempio n. 5
0
def test_component_factory():
    """
    Test building the "model" component from model.json with is_training=True.

    The built model must have a non-None ``linear`` attribute with
    in_features == 2 and out_features == 4.
    """
    Registry().clear_objects()

    model_json_file_path = "data/easytext/tests/component/model.json"
    model_json_file_path = os.path.join(ROOT_PATH, model_json_file_path)
    with open(model_json_file_path, encoding="utf-8") as f:
        config = json.load(f, object_pairs_hook=OrderedDict)

    factory = ComponentFactory(is_training=True)

    parsed_dict = factory.create(config=config)

    model = parsed_dict["model"]

    ASSERT.assertTrue(model.linear is not None)
    ASSERT.assertEqual((2, 4),
                       (model.linear.in_features, model.linear.out_features))
Esempio n. 6
0
def test_default_typename():
    """
    Test building a component from a config that relies on the default typename.

    When a class or function is registered without an explicit ``typename``
    (None or ""), its own ``__name__`` is used as the registered type name.
    default_typename.json is presumably written against that default name —
    confirm against the fixture. The built "default_typename" entry must have
    value 10.
    """
    Registry().clear_objects()
    config_json_file_path = "data/easytext/tests/component/default_typename.json"
    config_json_file_path = os.path.join(ROOT_PATH, config_json_file_path)
    with open(config_json_file_path, encoding="utf-8") as f:
        param_dict = json.load(f, object_pairs_hook=OrderedDict)

    factory = ComponentFactory(is_training=False)

    parsed_dict = factory.create(config=param_dict)

    default_typename = parsed_dict["default_typename"]

    ASSERT.assertEqual(10, default_typename.value)