Ejemplo n.º 1
0
def test(times, log_interval=1000):
    """Run ``times`` episodes with a fresh ``QLearningTable`` and record them.

    Each step asks the table for an action, refines it, applies it to an
    ``ECMA`` environment and appends ``[observation, refined_action, reward]``
    to the data saver under the episode index (as a string key). The recorded
    observation is the one read *after* the step. ``RL.learn`` is never
    called here, so the table is not updated by this function.

    Args:
        times: Number of episodes to run.
        log_interval: Print a progress line (episode index and wall-clock
            time) every ``log_interval`` episodes, skipping episode 0.
            ``0`` or ``None`` disables progress output.
    """
    data_saver = DataSaver("qlearning_test")
    data_saver.add_item("start_time", time.strftime("%Y-%m-%d_%H-%M-%S"))
    # Kept as a function-local import, as in the original code.
    from src.envs.ec.ec_env import ECMA
    env = ECMA()
    RL = QLearningTable()
    data_saver.add_item("qlearning_param", RL.param)
    data_saver.add_item("episodes", times)
    for episode in range(times):
        env.reset()
        # Flatten and round to 3 decimals so str(observation) is a compact,
        # repeatable key. Keep the list-of-numpy-scalars form: changing it
        # would change the string handed to the Q-table.
        observation = list(
            np.round(np.array(env.get_obs(), dtype=float).flatten(), 3)
        )
        while True:
            action = RL.choose_action(str(observation))
            refined_action = RL.refine_action(action)
            reward, done, _ = env.step(refined_action)
            # A brand-new list is built on every step and never mutated
            # afterwards, so no defensive deepcopy is needed before storing.
            observation = list(
                np.round(np.array(env.get_obs(), dtype=float).flatten(), 3)
            )
            data_saver.append(str(episode), [observation, refined_action, reward])
            if done:
                # The original test `episode / 1000 > 0` used true division
                # and therefore fired on every episode after the first.
                if log_interval and episode > 0 and episode % log_interval == 0:
                    print(episode, time.strftime("%Y-%m-%d_%H-%M-%S"))
                break
    data_saver.add_item("end_time", time.strftime("%Y-%m-%d_%H-%M-%S"))
    data_saver.to_file()
Ejemplo n.º 2
0
def train(times=10000, log_interval=1000, save_interval=10000):
    """Train a ``QLearningTable`` for ``times`` episodes and record the run.

    Each step asks the table for an action, refines it, applies it to the
    environment returned by ``create_env()``, and calls ``RL.learn`` with the
    stringified observation before and after the step. Every step is also
    appended to the data saver as ``[observation, refined_action, reward]``
    under the episode index (as a string key); the recorded observation is
    the one read *after* the step.

    Args:
        times: Number of training episodes.
        log_interval: Print a progress line (episode index and wall-clock
            time) every ``log_interval`` episodes, skipping episode 0.
            ``0`` or ``None`` disables progress output.
        save_interval: Call ``RL.save_model()`` every ``save_interval``
            episodes, skipping episode 0. ``0`` or ``None`` disables these
            periodic checkpoints. The model is always saved once more when
            the last episode finishes.
    """
    data_saver = DataSaver("qlearning_training")
    data_saver.add_item("start_time", time.strftime("%Y-%m-%d_%H-%M-%S"))
    # Create the environment.
    env = create_env()
    RL = QLearningTable()
    data_saver.add_item("RL_param", RL.param)
    data_saver.add_item("episodes", times)
    for episode in range(times):
        env.reset()
        # Flatten and round to 3 decimals so str(observation) is a compact,
        # repeatable key. Keep the list-of-numpy-scalars form: changing it
        # would change the string handed to the Q-table.
        observation = list(
            np.round(np.array(env.get_obs(), dtype=float).flatten(), 3)
        )
        while True:
            action = RL.choose_action(str(observation))
            refined_action = RL.refine_action(action)
            reward, done, _ = env.step(refined_action)
            observation_ = list(
                np.round(np.array(env.get_obs(), dtype=float).flatten(), 3)
            )
            RL.learn(str(observation), action, reward, str(observation_), done)
            # observation_ is a fresh list that is never mutated afterwards,
            # so it can be adopted directly without a deepcopy.
            observation = observation_
            data_saver.append(str(episode), [observation, refined_action, reward])
            if done:
                # The original checks `episode / 1000 > 0` and
                # `episode / 10000` used true division, so both fired on
                # every episode after the first (a model save per episode).
                if log_interval and episode > 0 and episode % log_interval == 0:
                    print(episode, time.strftime("%Y-%m-%d_%H-%M-%S"))
                is_checkpoint = bool(
                    save_interval and episode > 0 and episode % save_interval == 0
                )
                # Always persist the final model: with the defaults no
                # episode index ever reaches a non-zero multiple of 10000.
                if is_checkpoint or episode == times - 1:
                    RL.save_model()
                break
    data_saver.add_item("end_time", time.strftime("%Y-%m-%d_%H-%M-%S"))
    data_saver.to_file()