Esempio n. 1
0
def write_train_default_config():
    """
    Reset the YAML configuration files before each training run.

    In ``config/default.yaml``: clear ``checkpoint_path`` (set to ``""``) and
    set ``cal_max_expectation_tasks`` to ``False``.
    In ``config/algs/qmix.yaml``: set ``epsilon_finish`` to ``0.1``.
    Both files are written back via ``ModifyYAML.dump()``.
    :return: None
    """
    config_dir = os.path.join(os.path.dirname(__file__), "config")
    default = ModifyYAML(os.path.join(config_dir, "default.yaml"))
    alg_config = ModifyYAML(os.path.join(config_dir, "algs", "qmix.yaml"))

    data = default.data
    data["checkpoint_path"] = ""
    data["cal_max_expectation_tasks"] = False
    alg_config.data["epsilon_finish"] = 0.1

    # Persist default.yaml first, then qmix.yaml.
    for cfg in (default, alg_config):
        cfg.dump()
Esempio n. 2
0
def write_gen_default_config(checkpoint_path):
    """
    Update the YAML configuration files before generating data.

    In ``config/default.yaml``: set ``checkpoint_path`` to
    ``<this file's dir>/results/models/<checkpoint_path>`` and set
    ``cal_max_expectation_tasks`` to ``True``.
    In ``config/algs/qmix.yaml``: set ``epsilon_finish`` to ``0``.
    Both files are written back via ``ModifyYAML.dump()``.
    :param checkpoint_path: path component joined under ``results/models``.
    :return: None
    """
    base_dir = os.path.dirname(__file__)
    config_dir = os.path.join(base_dir, "config")
    default = ModifyYAML(os.path.join(config_dir, "default.yaml"))
    alg_config = ModifyYAML(os.path.join(config_dir, "algs", "qmix.yaml"))

    model_path = os.path.join(base_dir, "results", "models", checkpoint_path)
    data = default.data
    data["checkpoint_path"] = model_path
    data["cal_max_expectation_tasks"] = True
    alg_config.data["epsilon_finish"] = 0

    # Persist default.yaml first, then qmix.yaml.
    for cfg in (default, alg_config):
        cfg.dump()
Esempio n. 3
0
def cal_max_expectation_tasks(args, mac, learner, runner):
    """
    Use an already-trained model to generate state-action-reward data for evaluating it.

    Side effects:
      * sets ``epsilon_finish`` to ``0`` in ``config/algs/qmix.yaml`` and dumps it;
      * appends the computed expected reward as one line to
        ``envs/ec/output/rl_<label>.txt``, where ``label`` comes from ``get_label``.

    The number of episodes run is ``int(gen_t_max / env_args["max_steps"])``,
    both values read from ``config/envs/ec.yaml``.

    :param args: not used in this function body.
    :param mac: not used in this function body.
    :param learner: not used in this function body.
    :param runner: object whose ``run(test_mode=False)`` returns an episode batch
        exposing ``.data.transition_data``.
    :return: None
    """
    # NOTE(review): presumably zeroes the final exploration rate so data
    # generation acts greedily; confirm whether this takes effect for the
    # current run or only for later config loads.
    algs_modify = ModifyYAML(
        os.path.join(os.path.dirname(__file__), "config", "algs", "qmix.yaml"))
    algs_modify.data["epsilon_finish"] = 0
    algs_modify.dump()

    modify = ModifyYAML(
        os.path.join(os.path.dirname(__file__), "config", "envs", "ec.yaml"))
    global_state = []
    global_action = []  # NOTE(review): accumulated but not used below.
    global_reward = []
    episode = int(modify.data["gen_t_max"] /
                  modify.data["env_args"]["max_steps"])
    for _ in range(episode):
        episode_batch = runner.run(test_mode=False)
        episode_data = episode_batch.data.transition_data
        global_state += get_episode_state(episode_data)
        global_action += get_episode_action(episode_data)
        global_reward += get_episode_reward(episode_data)

    expected_reward = ExpectedReward(global_state,
                                     global_reward).get_expected_reward()

    label = get_label(modify)
    file_path = os.path.join(os.path.dirname(__file__), "envs", "ec", "output",
                             "rl_" + label + ".txt")
    # Explicit encoding so the output does not depend on the platform locale.
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(str(expected_reward) + "\n")