Example #1
0
    def test_source_scope_modification(self):
        """Re-scoping pyparams through the AST transformer must round-trip.

        Every pyparam of ``template7.py`` is given the ``test`` scope. The
        scoped params are rendered into the AST nodes where the originals
        were assigned, and the module is turned back into source. Parsing
        that regenerated source must yield exactly the scoped params.
        """
        template = self.sample_path / "template7.py"
        source_code = pyparam_parser.read_source_code(template)

        original_params = pyparam_parser.get_all_pyparams_from_source_code(
            source_code)
        named_nodes, module_ast = pyparam_parser.get_source_params_assignments(
            source_code)
        scoped_params = pyparam_parser.add_scope("test", original_params)

        # Key each assignment node (looked up by the original param's full
        # name) to the scoped param that should be rendered in its place.
        replacements = {
            named_nodes[old.full_name]: new
            for old, new in zip(original_params, scoped_params)
        }

        renderer = pyparam_parser.get_render_as_ast_node_transformer(
            replacements)
        rendered_module = renderer.visit(module_ast)
        rendered_source = astor.to_source(
            rendered_module,
            indent_with=pyparam_parser.COMPILED_SOURCE_INDENTATION,
            pretty_source=pyparam_parser.astor_pretty_source_formatter,
        )

        reparsed_params = pyparam_parser.get_all_pyparams_from_source_code(
            rendered_source)
        self.assertEqual(reparsed_params, scoped_params)
Example #2
0
    def test_version_check(self):
        """Compiling with ``validate_version=True`` rejects bad versions.

        ``template3.py`` must compile against ``template3_config.yml``. A
        ValueError is expected in two cases: the ``version`` entry is
        removed, or it is replaced by one whose value is ``"x"``.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "template3.py")
        config = pyparam_parser.read_yaml_file(self.sample_path /
                                               "template3_config.yml")

        new_source_code = pyparam_parser.compile_source_code(
            source_code=source_code, config=config, validate_version=True)

        # Explicit encoding: the platform default is locale-dependent.
        with open(self.source_code_tmp_path, "w", encoding="utf-8") as file:
            file.write(new_source_code)

        # Each invalid config is prepared *outside* assertRaises, so that
        # only the compile call under test can satisfy the expected error.
        del config["version"]
        with self.assertRaises(ValueError):
            pyparam_parser.compile_source_code(source_code=source_code,
                                               config=config,
                                               validate_version=True)

        config["version"] = {"value": "x", "dtype": "str"}
        with self.assertRaises(ValueError):
            pyparam_parser.compile_source_code(source_code=source_code,
                                               config=config,
                                               validate_version=True)
    def test_include_from_derived_module(self):
        """Including modules into ``derive_module.py`` pulls in their params.

        The included source must contain the listed params, the
        ``c/matmul`` scope, and the ``matmul2`` wrapper instance.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "derive_module.py")
        code = mods.include_modules(source_code, [self.sample_path])

        # assertIn names the missing fragment on failure, whereas
        # assertTrue(x in code) only reports "False is not true".
        expected_fragments = (
            'bias: float',
            'beta: float',
            "scope='c/matmul'",
            'matmul2: Module = _pyparam_module__matmul2()',
        )
        for fragment in expected_fragments:
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, code)
Example #4
0
    def module_source(self, base_path: Path) -> str:
        """Return the source of ``self.module_path`` found under *base_path*.

        The method walks ``base_path.rglob("*")``. Each directory that is the
        parent of some yielded entry is tried as a parent of
        ``self.module_path``, in traversal order. This includes *base_path*
        itself whenever it is non-empty. The first existing candidate is
        read with ``parser.read_source_code``.

        Args:
            base_path: Root directory to search recursively.

        Returns:
            The source text of the first matching module.

        Raises:
            FileNotFoundError: If no candidate path exists.
        """
        module_path = Path(self.module_path)
        searched = set()
        for entry in base_path.rglob("*"):
            parent = entry.parent
            # rglob yields every file and directory, so the same parent
            # shows up once per child; only check each directory once.
            if parent in searched:
                continue
            searched.add(parent)

            candidate = parent / module_path
            if candidate.exists():
                return parser.read_source_code(candidate)

        raise FileNotFoundError(
            f"Cannot find module: {self.module_path}, search path: {base_path}"
        )
 def test_derive_module(self):
     """``derive_module`` output must contain the expected IncludeModule lines."""
     source_code = pyparam_parser.read_source_code(self.sample_path /
                                                   "derive_module.py")
     code = mods.derive_module(source_code, [self.sample_path])
     # assertIn reports the missing text on failure; assertTrue(x in y)
     # only reports "False is not true".
     self.assertIn(
         "matmul1: Module = IncludeModule(path='fun_module', scope='a')",
         code)
     self.assertIn(
         "matmul2: Module = IncludeModule(path='fun2_module', scope='c')",
         code)
Example #6
0
    def test_source_code_compilation(self):
        """A config generated from ``template3.py`` compiles back to source.

        The source's params are exported to YAML, the YAML is read back, and
        the source is compiled with it. The result is written to
        ``self.source_code_tmp_path``. The test passes as long as no step
        raises.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "template3.py")
        pyparam_parser.source_to_yaml_config(source_code, self.config_tmp_path)
        config = pyparam_parser.read_yaml_file(Path(self.config_tmp_path))

        new_source_code = pyparam_parser.compile_source_code(
            source_code=source_code, config=config)

        # Explicit encoding: the locale default may fail on non-ASCII text
        # (e.g. param descriptions) in the compiled source.
        with open(self.source_code_tmp_path, "w", encoding="utf-8") as file:
            file.write(new_source_code)
 def test_include_source(self):
     """Included output for ``base_module_test.py`` contains inlined modules.

     The output must contain the include markers for ``fun2_module`` and
     ``fun_module``, plus the expected params and body lines.
     """
     source_code = pyparam_parser.read_source_code(self.sample_path /
                                                   "base_module_test.py")
     code = mods.include_modules(source_code, [self.sample_path])
     # assertIn reports which fragment is missing on failure.
     expected_fragments = (
         "bias: float = PyParam(1.1, float, 'matmul')",
         "PyParams: auto include source of `fun2_module`",
         "INCLUDE END OF `fun_module`",
         "  return alpha * matrix @ x + offset",
         "  alpha: float = PyParam(1.0, float, 'matmul')",
     )
     for fragment in expected_fragments:
         with self.subTest(fragment=fragment):
             self.assertIn(fragment, code)
    def test_include_module(self):
        """Including ``fun_module_import.py`` generates module wrapper code.

        The output must contain:

        - the ``_pyparam_module__matmul2`` class and its instance;
        - the ``self.matmul`` binding;
        - the ``offset`` pyparam scoped under both ``a/matmul`` and
          ``b/matmul``.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "fun_module_import.py")
        code = mods.include_modules(source_code, [self.sample_path])

        # assertIn names the missing fragment on failure, whereas
        # assertTrue(x in code) only reports "False is not true".
        expected_fragments = (
            "class _pyparam_module__matmul2():",
            "matmul2: Module = _pyparam_module__matmul2()",
            "    self.matmul = matmul",
            "offset: float = PyParam(value=1.0, dtype='float', "
            "scope='b/matmul', desc='')",
            "offset: float = PyParam(value=1.0, dtype='float', "
            "scope='a/matmul', desc='')",
        )
        for fragment in expected_fragments:
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, code)
Example #9
0
    def test_update_pyparams(self):
        """``update_source_pyparams`` writes scoped params into the source.

        The params of ``template7.py`` are re-scoped under ``test`` and
        written back into the source. Parsing the updated source must return
        exactly those scoped params.
        """
        template7 = self.sample_path / "template7.py"
        source_code = pyparam_parser.read_source_code(template7)
        scoped = pyparam_parser.add_scope(
            "test",
            pyparam_parser.get_all_pyparams_from_source_code(source_code))

        updated_source = pyparam_parser.update_source_pyparams(
            source_code, scoped)
        self.assertEqual(
            pyparam_parser.get_all_pyparams_from_source_code(updated_source),
            scoped)
Example #10
0
    def test_save_and_load_check_descriptions(self):
        """Params of ``template7.py`` survive a YAML round trip.

        The params loaded from the generated YAML config must equal those
        parsed directly from the source. Compiling the source with that
        config must also not raise.
        """
        source_code = pyparam_parser.read_source_code(
            self.sample_path / "template7.py")
        expected = pyparam_parser.get_all_pyparams_from_source_code(
            source_code)

        pyparam_parser.source_to_yaml_config(source_code, self.config_tmp_path)
        config = pyparam_parser.read_yaml_file(self.config_tmp_path)

        self.assertEqual(
            expected, pyparam_parser.read_params_from_config(config))
        # Smoke check only: the compiled source itself is not inspected.
        pyparam_parser.compile_source_code(source_code, config)
Example #11
0
    def test_get_all_pyparams_from_source_code(self):
        """``template1.py`` yields exactly ``start_index`` then ``max_iters``."""
        source_code = pyparam_parser.read_source_code(
            self.sample_path / "template1.py")
        found = pyparam_parser.get_all_pyparams_from_source_code(source_code)

        start_index = NamedPyParam(
            "start_index",
            PyParam(1, int, scope="loop", desc="summation start index"),
        )
        max_iters = NamedPyParam(
            "max_iters",
            PyParam(6, int, "loop", "max number of iterations"),
        )

        self.assertEqual(found, [start_index, max_iters])
Example #12
0
    def test_ast_assign_to_pyparam(self):
        """``NamedPyParam.from_ast_node`` decodes the assignments in template1."""
        source_code = pyparam_parser.read_source_code(
            self.sample_path / "template1.py")
        tree = ast.parse(source=source_code)
        nodes = pyparam_parser.find_pyparams_assignments_nodes(tree)

        # (node index, expected decoded param) pairs, checked in order.
        cases = (
            (0, NamedPyParam(
                "start_index",
                PyParam(1, int, scope="loop", desc="summation start index"))),
            (1, NamedPyParam(
                "max_iters",
                PyParam(6, int, "loop", "max number of iterations"))),
        )
        for index, expected in cases:
            decoded = NamedPyParam.from_ast_node(nodes[index])
            self.assertEqual(decoded, expected)
Example #13
0
    def test_params_wo_annotations_in_functions_def(self):
        """Pyparams in ``template9.py`` function defs round-trip and compile.

        The params parsed from the source must equal those loaded back from
        the generated YAML. The compiled source must also contain the
        expected signatures, the call site, and nested assignments. Both
        annotated and unannotated forms appear in the expected text.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "template9.py")
        pyparams = pyparam_parser.get_all_pyparams_from_source_code(
            source_code)
        pyparam_parser.source_to_yaml_config(source_code, self.config_tmp_path)
        config = pyparam_parser.read_yaml_file(self.config_tmp_path)
        loaded_pyparams = pyparam_parser.read_params_from_config(config)

        self.assertEqual(pyparams, loaded_pyparams)
        compiled_source = pyparam_parser.compile_source_code(
            source_code, config)

        # assertIn names the missing fragment on failure, whereas
        # assertTrue(x in y) only reports "False is not true".
        expected_fragments = (
            "some_function(x, y, param2: int=2, param3: float=3, "
            "param4: int=4, param5=5, param6=6)",
            "self, arg1: float=1.1, arg2=2.2",
            "result = some_function(0, 1, param2=12, param3=13)",
            "  param2 = 2",
            "  param3: int = 3",
            "  def nested_function2(x, y, np2: int=2)",
        )
        for fragment in expected_fragments:
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, compiled_source)
 def test_include_source_module_decorators(self):
     """Smoke test: including ``fun_module_import_decorators.py`` must not raise.

     Only the absence of an exception is checked; the output is not inspected.
     """
     sample = self.sample_path / "fun_module_import_decorators.py"
     mods.include_modules(pyparam_parser.read_source_code(sample),
                          [self.sample_path])
Example #15
0
    def test_loading_yaml(self):
        """Params loaded from YAML generated for ``template3.py`` are exact.

        The config written by ``source_to_yaml_config`` must load into this
        exact list of params, in this order.
        """
        source_code = pyparam_parser.read_source_code(
            self.sample_path / "template3.py")
        pyparam_parser.source_to_yaml_config(source_code, self.config_tmp_path)
        loaded = pyparam_parser.read_params_from_config(
            pyparam_parser.read_yaml_file(Path(self.config_tmp_path)))

        def feature_param(name, value, dtype, desc=""):
            # Every expected param except ``version`` uses this scope.
            return NamedPyParam(
                name=name,
                param=PyParam(value=value,
                              dtype=dtype,
                              scope="feature_extractor",
                              desc=desc),
            )

        expected = [
            NamedPyParam(
                name="version",
                param=PyParam(value="1.0",
                              dtype=str,
                              scope="",
                              desc="model version"),
            ),
            feature_param("base_num_filters", 4, int),
            feature_param("include_root", False, bool),
            feature_param("regularize_depthwise", False, bool),
            feature_param("activation_fn_in_separable_conv", False, bool),
            feature_param("entry_flow_blocks", (1, 1, 1), tuple,
                          "Number of units in each bock in the entry flow."),
            feature_param("middle_flow_blocks", (1, ), tuple,
                          "Number of units in the middle flow."),
        ]

        self.assertEqual(loaded, expected)
Example #16
0
 def test_find_pyparams_assignments(self):
     """``template1.py`` contains exactly two pyparam assignment nodes."""
     template = self.sample_path / "template1.py"
     tree = ast.parse(source=pyparam_parser.read_source_code(template))
     found = pyparam_parser.find_pyparams_assignments_nodes(tree)
     self.assertEqual(len(found), 2)
Example #17
0
 def test_to_yaml(self):
     """Smoke test: exporting ``template3.py`` params to YAML must not raise."""
     template = self.sample_path / "template3.py"
     pyparam_parser.source_to_yaml_config(
         pyparam_parser.read_source_code(template), self.config_tmp_path)
Example #18
0
    def test_dict_dtype_pyparam(self):
        """Dict- and list-valued pyparams in ``template6.py`` are decoded.

        The first four pyparam assignment nodes must decode to the expected
        nested values, all in the ``model`` scope. The template must then
        survive a YAML export/import and compile with
        ``validate_version=False``.
        """
        source_code = pyparam_parser.read_source_code(self.sample_path /
                                                      "template6.py")
        root = ast.parse(source=source_code)
        pyparam_nodes = pyparam_parser.find_pyparams_assignments_nodes(root)

        def model_param(name, value, dtype, desc):
            # All expected params in this template share the "model" scope.
            return NamedPyParam(
                name=name,
                param=PyParam(value=value,
                              dtype=dtype,
                              scope="model",
                              desc=desc),
            )

        # Expected decoded params, indexed like ``pyparam_nodes``.
        expected_params = [
            model_param("foo1_dict", {"a": 1, "b": 2}, dict, "foo1"),
            model_param("foo2_dict", {"a": [1, 1, 2]}, dict, "foo2"),
            model_param(
                "foo3_dict",
                {
                    "a": {"aa": 3, "ab": [1, 3]},
                    "b": [1, 2, 3],
                    "c": "test",
                },
                dict,
                "foo2",
            ),
            model_param(
                "foo4_dict",
                [
                    {
                        "a": {"aa": 3, "ab": [1, 3]},
                        "b": [1, 2, 3],
                        "c": "test",
                    },
                    {
                        "A": {"AA": 15., "AB": [1]},
                        "B": [2, 3],
                        "C": "TEST",
                    },
                ],
                list,
                "foo4 nested dict in list",
            ),
        ]
        for index, expected in enumerate(expected_params):
            with self.subTest(node_index=index):
                self.assertEqual(
                    NamedPyParam.from_ast_node(pyparam_nodes[index]),
                    expected)

        # Round-trip through YAML, reusing the template source read above.
        pyparam_parser.source_to_yaml_config(source_code, self.config_tmp_path)
        config = pyparam_parser.read_yaml_file(Path(self.config_tmp_path))

        pyparam_parser.compile_source_code(source_code,
                                           config,
                                           validate_version=False)