예제 #1
0
    def _do_parse(self, config, id: str = ""):
        """
        Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver.

        Nested items of a ``dict`` or ``list`` are visited (and added) before the item that contains them.
        Every item is deep-copied before it is added, so the registered items do not share
        state with ``config`` or with each other.

        Args:
            config: config source to parse.
            id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
                go one level further into the nested structures.
                Use digits indexing from "0" for list or other strings for dict.
                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.

        """
        if isinstance(config, (dict, list)):
            children = enumerate(config) if isinstance(config, list) else config.items()
            for k, v in children:
                # always produce a string id: a top-level list would otherwise register
                # integer ids (0, 1, ...) while every nested level uses strings ("a#0")
                sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else f"{k}"
                self._do_parse(config=v, id=sub_id)

        # copy every config item to make them independent and add them to the resolver
        item_conf = deepcopy(config)
        if ConfigComponent.is_instantiable(item_conf):
            self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator))
        elif ConfigExpression.is_expression(item_conf):
            self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals))
        else:
            self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))
예제 #2
0
    def test_resolve(self, configs, expected_id, output_type):
        """
        Check that the content resolved for ``expected_id`` is an instance of ``output_type``.

        Args:
            configs: mapping of item id to raw config content; each entry is wrapped in
                ``ConfigComponent``, ``ConfigExpression`` or ``ConfigItem`` and added to the resolver.
            expected_id: id of the item to resolve.
            output_type: type (or tuple of types) expected for the resolved result.

        """
        locator = ComponentLocator()
        resolver = ReferenceResolver()
        # add items to resolver, choosing the wrapper class from the config content
        for k, v in configs.items():
            if ConfigComponent.is_instantiable(v):
                new_item = ConfigComponent(config=v, id=k, locator=locator)
            elif ConfigExpression.is_expression(v):
                new_item = ConfigExpression(config=v, id=k, globals={"monai": monai, "torch": torch})
            else:
                new_item = ConfigItem(config=v, id=k)
            resolver.add_item(new_item)

        result = resolver.get_resolved_content(expected_id)  # the root id is `expected_id` here
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...))
        self.assertIsInstance(result, output_type)

        # test lazy instantiation
        item = resolver.get_item(expected_id, resolve=True)
        config = item.get_config()
        config["_disabled_"] = False
        item.update_config(config=config)
        if isinstance(item, ConfigComponent):
            result = item.instantiate()
        else:
            result = item.get_config()
        self.assertIsInstance(result, output_type)
예제 #3
0
    def add_item(self, item: ConfigItem):
        """
        Register a ``ConfigItem`` in the resolver under its own id.

        If an item with the same id is already registered, the existing one is kept
        and ``item`` is ignored.

        Args:
            item: the ``ConfigItem`` to register.

        """
        item_id = item.get_id()
        if item_id not in self.items:
            self.items[item_id] = item
예제 #4
0
 def test_parse(self, config, expected_ids, output_types):
     """
     Exercise ``ConfigParser`` item access, nested-id updates and ``get_parsed_content``.

     Args:
         config: config content used to construct the parser.
         expected_ids: ids whose parsed content is checked.
         output_types: expected types, paired position by position with ``expected_ids``.

     """
     parser = ConfigParser(config=config, globals={"monai": "monai"})
     # test lazy instantiation with original config content
     parser["transform"]["transforms"][0]["keys"] = "label1"
     trans = parser.get_parsed_content(id="transform#transforms#0")
     self.assertEqual(trans.keys[0], "label1")
     # test re-use the parsed content or not with the `lazy` option
     self.assertEqual(trans, parser.get_parsed_content(id="transform#transforms#0"))
     self.assertEqual(trans, parser.get_parsed_content(id="transform#transforms#0", lazy=True))
     self.assertNotEqual(trans, parser.get_parsed_content(id="transform#transforms#0", lazy=False))
     # test nested id
     parser["transform#transforms#0#keys"] = "label2"
     self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label2")
     # loop names avoid shadowing the builtin `id`; assertIsInstance reports the actual type on failure
     for item_id, expected_type in zip(expected_ids, output_types):
         self.assertIsInstance(parser.get_parsed_content(item_id), expected_type)
     # test root content
     root = parser.get_parsed_content(id="")
     for v, expected_type in zip(root.values(), [Compose, Dataset, DataLoader]):
         self.assertIsInstance(v, expected_type)
     # test default value
     self.assertEqual(parser.get_parsed_content(id="abc", default=ConfigItem(12345, "abc")), 12345)