예제 #1
0
파일: configuration.py 프로젝트: tswsxk/XKT
    def __init__(self, params_json=None, **kwargs):
        """
        Configuration File, including categories:

        * directory setting
        * optimizer setting
        * training parameters
        * equipment
        * parameters saving setting
        * user parameters

        Values are layered as: class-level defaults, then the json file
        (``params_json``), then ``kwargs``.

        Parameters
        ----------
        params_json: str
            The path to configuration file which is in json format
        kwargs:
            Parameters to be reset. A key ``<name>_params_update`` (e.g.
            ``optimizer_params_update``) is merged into ``<name>_params``.
            Directory settings may contain ``$name`` placeholders, which are
            turned into ``self.name`` references by ``var2exp`` and evaluated.
        """
        super(Configuration, self).__init__(logger=config_logging(
            logger=self.model_name, console_log_level=LogLevel.INFO))

        params = self.class_var
        if params_json:
            params.update(self.load_cfg(params_json=params_json))
        params.update(**kwargs)

        # Merge "<name>_params_update" into "<name>_params". A new dict is
        # built instead of calling .update() on the existing value: that value
        # may be the class-level default dict (shared by every instance), and
        # mutating it in place would leak these overrides into every later
        # Configuration. A None base (e.g. lr_params) is treated as empty.
        for key in params:
            update_key = key + "_update"
            if key.endswith("_params") and update_key in params:
                merged = dict(params[key] or {})
                merged.update(params[update_key])
                params[key] = merged

        # A path setting counts as "overridden" only when it was passed
        # explicitly in kwargs with a value different from the current
        # attribute. The instance attributes are assigned just below, so here
        # getattr() still returns the pre-override values.
        path_check_list = [
            "dataset", "root_data_dir", "workspace", "root_model_dir",
            "model_dir"
        ]
        _overridden = {
            name: (kwargs.get(name) is not None
                   and kwargs[name] != getattr(self, name))
            for name in path_check_list
        }

        for param, value in params.items():
            setattr(self, "%s" % param, value)

        # set dataset: derive root_data_dir unless it was given explicitly
        if _overridden["dataset"] and not _overridden["root_data_dir"]:
            kwargs["root_data_dir"] = path_append("$root", "data", "$dataset")
        # set workspace: derive model_dir unless it was given explicitly.
        # Use the "$workspace" placeholder (not the literal "workspace") so the
        # caller's workspace name actually ends up in the path.
        if ((_overridden["workspace"] or _overridden["root_model_dir"])
                and not _overridden["model_dir"]):
            kwargs["model_dir"] = path_append("$root_model_dir", "$workspace")

        # Rebuild the directory settings by resolving their "$name"
        # placeholders against self. The order matters: later entries (e.g.
        # data_dir, model_dir) may refer to earlier ones that are already
        # resolved.
        # NOTE(review): eval() executes expressions built from kwargs / the
        # json config file; only load configuration from trusted sources.
        _dirs = [
            "workspace", "root_data_dir", "data_dir", "root_model_dir",
            "model_dir"
        ]
        for _dir in _dirs:
            exp = var2exp(kwargs.get(_dir, getattr(self, _dir)),
                          env_wrap=lambda x: "self.%s" % x)
            setattr(self, _dir, eval(exp))

        # Presumably eval_var turns a textual value into the real object
        # (e.g. a device spec) — confirm. On TypeError the raw kwargs value
        # that was assigned above is kept (deliberate best-effort).
        _vars = ["ctx"]
        for _var in _vars:
            if _var in kwargs:
                try:
                    setattr(self, _var, eval_var(kwargs[_var]))
                except TypeError:
                    pass

        self.validation_result_file = path_append(self.model_dir,
                                                  "result.json",
                                                  to_str=True)
        self.cfg_path = path_append(self.model_dir,
                                    "configuration.json",
                                    to_str=True)
예제 #2
0
    def __init__(self, params_json=None, **kwargs):
        """
        Build the configuration from class defaults, an optional json file
        and keyword overrides.

        Covered settings: directories, optimizer, training parameters,
        equipment, parameter-saving options and user parameters.

        Parameters
        ----------
        params_json: str
            Path of a json configuration file whose entries override the
            class-level defaults.
        kwargs:
            Individual settings; they override both the class defaults and
            the json file.
        """
        console_logger = config_logging(logger=self.model_name,
                                        console_log_level=LogLevel.INFO)
        super(Configuration, self).__init__(logger=console_logger)

        # Precedence (lowest -> highest): class defaults, json file, kwargs.
        settings = self.class_var
        if params_json:
            settings.update(self.load_cfg(params_json=params_json))
        settings.update(kwargs)

        for name in settings:
            setattr(self, "%s" % name, settings[name])

        # Only the dataset given: keep its data under $root/data/<dataset>.
        if not kwargs.get("root_data_dir") and kwargs.get("dataset"):
            kwargs["root_data_dir"] = "$root/data/$dataset"

        # Workspace or model root changed but no explicit model_dir: nest the
        # model directory under the workspace.
        model_location_changed = bool(kwargs.get("workspace")
                                      or kwargs.get("root_model_dir"))
        if model_location_changed and not kwargs.get("model_dir"):
            kwargs["model_dir"] = "$root_model_dir/$workspace"

        def _as_self_attr(name):
            return "self.%s" % name

        # Resolve "$name" placeholders in order, so later directories can
        # build on ones that were already resolved.
        # NOTE(review): eval() runs expressions derived from configuration
        # input; only use trusted configuration files.
        for dir_name in ("workspace", "root_data_dir", "data_dir",
                         "root_model_dir", "model_dir"):
            expression = var2exp(kwargs.get(dir_name,
                                            getattr(self, dir_name)),
                                 env_wrap=_as_self_attr)
            setattr(self, dir_name, eval(expression))

        if "ctx" in kwargs:
            try:
                self.ctx = eval_var(kwargs["ctx"])
            except TypeError:
                # Best-effort: keep the value that was assigned verbatim above.
                pass

        model_dir = self.model_dir
        self.validation_result_file = path_append(model_dir, "result.json",
                                                  to_str=True)
        self.cfg_path = path_append(model_dir, "configuration.json",
                                    to_str=True)
예제 #3
0
파일: configuration.py 프로젝트: tswsxk/XKT
class Configuration(parser.Configuration):
    # ---- directory settings ----
    # Name of the directory two levels above this file (parents[1]).
    model_name = str(pathlib.Path(__file__).parents[1].name)

    # Directory four levels above this file (parents[3]).
    root = pathlib.Path(__file__).parents[3]
    # root = "./"
    dataset = ""
    # Evaluated once, when the class body runs (module import).
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    workspace = ""

    # "$name" placeholders below are resolved against the instance in
    # __init__ (via var2exp + eval).
    root_data_dir = path_append("$root", "data",
                                "$dataset") if dataset else path_append(
                                    "$root", "data")
    data_dir = path_append("$root_data_dir", "data")
    root_model_dir = path_append("$root_data_dir", "model", "$model_name")
    model_dir = path_append("$root_model_dir",
                            "$workspace") if workspace else root_model_dir
    cfg_path = path_append("$model_dir", "configuration.json")

    root = str(root)
    root_data_dir = str(root_data_dir)
    root_model_dir = str(root_model_dir)

    # ---- training parameters ----
    begin_epoch = 0
    end_epoch = 20
    batch_size = 16
    save_epoch = 1

    # ---- optimizer settings ----
    # optimizer, optimizer_params = get_optimizer_cfg(name="base")
    optimizer = "adam"
    optimizer_params = {
        "learning_rate": 1e-3,
    }
    lr_params = None
    # {
    #     "learning_rate": 10e-3,
    #     "step": 100,
    #     "max_update_steps": get_update_steps(
    #         update_epoch=10,
    #         batches_per_epoch=1000,
    #     ),
    # }

    # ---- parameters to update (train) and to save; usually kept identical
    train_select = _select
    save_select = train_select

    # ---- runtime device ----
    ctx = cpu(0)

    # ---- toolbox parameters ----
    toolbox_params = {}

    # ---- user variables ----
    num_buckets = 100
    # hyperparameters
    # network hyperparameters
    hyper_params = {}
    # network initialization parameters
    init_params = {}
    # loss-function hyperparameters
    loss_params = {}
    # free-text description of the run
    caption = ""

    def __init__(self, params_json=None, **kwargs):
        """
        Configuration File, including categories:

        * directory setting
        * optimizer setting
        * training parameters
        * equipment
        * parameters saving setting
        * user parameters

        Values are layered as: class-level defaults, then the json file
        (``params_json``), then ``kwargs``.

        Parameters
        ----------
        params_json: str
            The path to configuration file which is in json format
        kwargs:
            Parameters to be reset. A key ``<name>_params_update`` (e.g.
            ``optimizer_params_update``) is merged into ``<name>_params``.
            Directory settings may contain ``$name`` placeholders, which are
            turned into ``self.name`` references by ``var2exp`` and evaluated.
        """
        super(Configuration, self).__init__(logger=config_logging(
            logger=self.model_name, console_log_level=LogLevel.INFO))

        params = self.class_var
        if params_json:
            params.update(self.load_cfg(params_json=params_json))
        params.update(**kwargs)

        # Merge "<name>_params_update" into "<name>_params". A new dict is
        # built instead of calling .update() on the existing value: that value
        # may be the class-level default dict (e.g. optimizer_params above,
        # shared by every instance), and mutating it in place would leak these
        # overrides into every later Configuration. A None base (lr_params) is
        # treated as empty.
        for key in params:
            update_key = key + "_update"
            if key.endswith("_params") and update_key in params:
                merged = dict(params[key] or {})
                merged.update(params[update_key])
                params[key] = merged

        # A path setting counts as "overridden" only when it was passed
        # explicitly in kwargs with a value different from the class default.
        # The instance attributes are assigned just below, so here getattr()
        # still returns the class defaults.
        path_check_list = [
            "dataset", "root_data_dir", "workspace", "root_model_dir",
            "model_dir"
        ]
        _overridden = {
            name: (kwargs.get(name) is not None
                   and kwargs[name] != getattr(self, name))
            for name in path_check_list
        }

        for param, value in params.items():
            setattr(self, "%s" % param, value)

        # set dataset: derive root_data_dir unless it was given explicitly
        if _overridden["dataset"] and not _overridden["root_data_dir"]:
            kwargs["root_data_dir"] = path_append("$root", "data", "$dataset")
        # set workspace: derive model_dir unless it was given explicitly.
        # Use the "$workspace" placeholder (not the literal "workspace"),
        # matching the class-level default, so the caller's workspace name
        # actually ends up in the path.
        if ((_overridden["workspace"] or _overridden["root_model_dir"])
                and not _overridden["model_dir"]):
            kwargs["model_dir"] = path_append("$root_model_dir", "$workspace")

        # Rebuild the directory settings by resolving their "$name"
        # placeholders against self. The order matters: data_dir and
        # root_model_dir refer to root_data_dir, and model_dir refers to
        # root_model_dir/workspace, which are resolved first.
        # NOTE(review): eval() executes expressions built from kwargs / the
        # json config file; only load configuration from trusted sources.
        _dirs = [
            "workspace", "root_data_dir", "data_dir", "root_model_dir",
            "model_dir"
        ]
        for _dir in _dirs:
            exp = var2exp(kwargs.get(_dir, getattr(self, _dir)),
                          env_wrap=lambda x: "self.%s" % x)
            setattr(self, _dir, eval(exp))

        # Presumably eval_var turns a textual value (e.g. "cpu(0)") into the
        # real device object — confirm. On TypeError the raw kwargs value that
        # was assigned above is kept (deliberate best-effort).
        _vars = ["ctx"]
        for _var in _vars:
            if _var in kwargs:
                try:
                    setattr(self, _var, eval_var(kwargs[_var]))
                except TypeError:
                    pass

        self.validation_result_file = path_append(self.model_dir,
                                                  "result.json",
                                                  to_str=True)
        self.cfg_path = path_append(self.model_dir,
                                    "configuration.json",
                                    to_str=True)

    def dump(self, cfg_path=None, override=False):
        """Dump the configuration via the parent ``dump``.

        ``cfg_path`` defaults to ``self.cfg_path``; ``override`` is forwarded
        unchanged to the parent implementation.
        """
        cfg_path = self.cfg_path if cfg_path is None else cfg_path
        super(Configuration, self).dump(cfg_path, override)

    @staticmethod
    def load(cfg_path, **kwargs):
        """Build a Configuration from the configuration stored at ``cfg_path``."""
        # NOTE(review): the dict returned by load_cfg is passed positionally,
        # i.e. as ``params_json``, and __init__ then calls
        # ``load_cfg(params_json=<that dict>)``. Unless load_cfg accepts an
        # already-loaded dict, this should probably be
        # ``Configuration(**Configuration.load_cfg(cfg_path, **kwargs))`` —
        # confirm against parser.Configuration.load_cfg.
        return Configuration(Configuration.load_cfg(cfg_path, **kwargs))

    def var2val(self, var):
        """Evaluate ``var`` with its ``$name`` placeholders mapped to ``self.name``."""
        # NOTE(review): uses eval(); only call with trusted input.
        return eval(var2exp(var, env_wrap=lambda x: "self.%s" % x))