def __init__(self, is_training: bool, config_file_path: str, encoding: str = "utf-8"):
    """
    Read a JSON configuration file and build the components it describes.

    The parsed JSON (key order preserved through ``OrderedDict``) is kept
    in ``_raw_config``; the result of ``ComponentFactory.create`` on it is
    kept in ``_config``.

    :param is_training: flag stored on the instance and handed to
        ``ComponentFactory``
    :param config_file_path: path of the JSON configuration file
    :param encoding: text encoding used to open the file
    """
    with open(config_file_path, encoding=encoding) as config_file:
        loaded: OrderedDict = json.load(config_file,
                                        object_pairs_hook=OrderedDict)
        self._raw_config = loaded

    self._is_training = is_training

    # The factory receives the very same mapping that is kept as
    # ``_raw_config`` (no copy is made here).
    self._config = ComponentFactory(self._is_training).create(self._raw_config)
def test_component_evaluate_factory():
    """
    Build the components of ``training.json`` with ``is_training=False``
    and check that ``my_component.value`` equals ``"evaluate_3"``.
    """
    Registry().clear_objects()

    json_path = os.path.join(ROOT_PATH,
                             "data/easytext/tests/component/training.json")

    with open(json_path, encoding="utf-8") as json_file:
        raw_config = json.load(json_file, object_pairs_hook=OrderedDict)

    components = ComponentFactory(is_training=False).create(config=raw_config)

    ASSERT.assertEqual("evaluate_3", components["my_component"].value)
def test_component_with_object():
    """
    Test building components when one of the constructor parameters is
    itself an object.

    Loads ``component_with_obj.json``, builds everything with
    ``is_training=False`` and checks both the attribute values and that
    the objects referenced by the components are the very same instances
    as the ones returned under their own names.
    """
    Registry().clear_objects()

    json_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/component_with_obj.json")

    with open(json_path, encoding="utf-8") as json_file:
        raw_config = json.load(json_file, object_pairs_hook=OrderedDict)

    built = ComponentFactory(is_training=False).create(config=raw_config)

    # Stand-alone object.
    my_obj = built["my_obj"]
    ASSERT.assertEqual(10, my_obj.value)

    # Component holding a sub model and a reference to ``my_obj``.
    my_component: _ModelWithObjParam = built["my_component"]
    ASSERT.assertEqual(4, my_component.sub_model.in_features)
    ASSERT.assertEqual(2, my_component.sub_model.out_features)
    ASSERT.assertTrue(my_obj is my_component.customer_obj)

    # A second component must be a distinct instance.
    another_component: _ModelWithObjParam = built["another_component"]
    ASSERT.assertTrue(my_component is not another_component)

    another_obj: _CustomerObj = built["another_obj"]
    ASSERT.assertTrue(another_obj is another_component.customer_obj)
    ASSERT.assertEqual(20, another_obj.value)

    # Component with a dict parameter; the dict mixes plain values and an
    # object.
    dict_param_component: _DictParamComponent = built["dict_param_component"]
    ASSERT.assertTrue(dict_param_component.curstomer_obj is another_obj)
    for key, expected in (("a", 1), ("b", 2)):
        ASSERT.assertEqual(expected, dict_param_component.dict_value[key])
    ASSERT.assertEqual(30, dict_param_component.dict_value["c_obj"].value)

    # Plain string entry.
    ASSERT.assertEqual("my_test_value", built["my_object"])
def test_component_factory():
    """
    Build the components of ``model.json`` with ``is_training=True`` and
    check that ``model.linear`` exists with ``in_features == 2`` and
    ``out_features == 4``.
    """
    Registry().clear_objects()

    json_path = os.path.join(ROOT_PATH,
                             "data/easytext/tests/component/model.json")

    with open(json_path, encoding="utf-8") as json_file:
        raw_config = json.load(json_file, object_pairs_hook=OrderedDict)

    components = ComponentFactory(is_training=True).create(config=raw_config)
    model = components["model"]

    ASSERT.assertTrue(model.linear is not None)

    actual_shape = (model.linear.in_features, model.linear.out_features)
    ASSERT.assertEqual((2, 4), actual_shape)
def test_default_typename():
    """
    Build the components of ``default_typename.json`` with
    ``is_training=False`` and check that the ``default_typename`` entry
    has ``value == 10``.
    """
    Registry().clear_objects()

    json_path = os.path.join(
        ROOT_PATH, "data/easytext/tests/component/default_typename.json")

    with open(json_path, encoding="utf-8") as json_file:
        raw_config = json.load(json_file, object_pairs_hook=OrderedDict)

    components = ComponentFactory(is_training=False).create(config=raw_config)

    ASSERT.assertEqual(10, components["default_typename"].value)
def __setstate__(self, state):
    """
    Restore the instance from ``state`` and rebuild ``_config``.

    ``state`` is merged into the instance ``__dict__``; ``_config`` is then
    recreated by a fresh ``ComponentFactory`` from a shallow copy of
    ``_raw_config``, replacing whatever ``_config`` value ``state`` may
    have carried.

    :param state: mapping of attribute names to values; it must provide
        ``_is_training`` and ``_raw_config`` unless they are already set
    """
    vars(self).update(state)

    # Only a shallow copy of the raw configuration goes to the factory.
    self._config = ComponentFactory(self._is_training).create(
        self._raw_config.copy())