示例#1
0
文件: config.py 项目: zeta1999/kge
    def __init__(self, folder: Optional[str] = None, load_default=True):
        """Initialize the configuration.

        Args:
            folder: main folder of this configuration (config file,
                checkpoints, ...); may be ``None``.
            load_default: if true, ``self.options`` is populated from the
                ``config-default.yaml`` file located via
                ``filename_in_module(kge, ...)``; otherwise ``self.options``
                starts out as an empty dict.

        """
        if load_default:
            import kge
            from kge.misc import filename_in_module

            # Decode explicitly as UTF-8 so that reading the packaged defaults
            # does not depend on the platform's locale encoding. SafeLoader
            # keeps the YAML from constructing arbitrary Python objects.
            with open(
                filename_in_module(kge, "config-default.yaml"),
                "r",
                encoding="utf-8",
            ) as file:
                self.options: Dict[str, Any] = yaml.load(
                    file, Loader=yaml.SafeLoader
                )
        else:
            self.options = {}

        self.folder = folder  # main folder (config file, checkpoints, ...)
        # None means use self.folder; used for kge.log, trace.yaml
        self.log_folder: Optional[str] = None
        # initialized to None here, hence Optional
        self.log_prefix: Optional[str] = None
示例#2
0
    def _import(self, module_name: str):
        """Imports the specified module configuration.

        Loads ``<module_name>.yaml`` (located via ``filename_in_module`` in
        ``kge.model`` / ``kge.model.embedder``) and adds its options to this
        configuration. Options that are already set in this configuration keep
        their current values; they are written into the module configuration
        with ``set_all(..., create=False)`` so that fields and their types are
        verified against the module's configuration.

        The module file's own ``import`` entry is discarded. Instead,
        ``module_name`` is recorded under this configuration's ``import`` key,
        which holds either a single name (str) or a list of names, without
        duplicates.

        Args:
            module_name: name of the module configuration to import.

        """
        import kge.model
        import kge.model.embedder
        from kge.misc import filename_in_module

        # load the module's configuration into a fresh, empty Config
        module_config = Config(load_default=False)
        module_config.load(
            filename_in_module(
                [kge.model, kge.model.embedder], "{}.yaml".format(module_name)
            ),
            create=True,
        )
        # imports are tracked separately below, so the module's own
        # "import" entry must not be merged into this configuration
        module_config.options.pop("import", None)

        # overlay the values currently set here onto the module configuration
        for key in module_config.options:
            try:
                cur_value = {key: self.get(key)}
            except KeyError:
                # not set in this configuration: keep the module's value
                continue
            module_config.set_all(cur_value, create=False)

        # now update this configuration
        self.set_all(module_config.options, create=True)

        # remember the import; never record the same module twice
        imports = self.options.get("import")
        if imports is None:
            imports = module_name
        elif isinstance(imports, str):
            if imports != module_name:
                imports = [imports, module_name]
        else:
            # build a new list (order-preserving de-duplication) instead of
            # mutating the stored one in place
            imports = list(dict.fromkeys(list(imports) + [module_name]))
        self.options["import"] = imports