Example #1
0
def train_ppo(*, total_timesteps=1_000_000, seed=1):
    """Train a PPO agent on the training split of ``ENV_CONTINUE`` and save it.

    The environment is wrapped in a ``Monitor`` that logs episode statistics
    under ``model_save/``. ``SaveOnBestTrainingRewardCallback`` is checked every
    480 steps against that same directory. The final model is saved to
    ``model_save/ppo``.

    Args:
        total_timesteps: Number of environment steps to train for.
        seed: Random seed forwarded to ``PPO``.
    """
    import os

    log_dir = "model_save/"
    # Create the log directory up front (as in the SB3 callback example):
    # Monitor writes its CSV inside it and the callback loads results from it.
    os.makedirs(log_dir, exist_ok=True)

    monitored_env = Monitor(ENV_CONTINUE(istest=False), log_dir)
    # Bind the monitored env as a default argument instead of closing over a
    # name that is rebound on this very line.
    env = DummyVecEnv([lambda e=monitored_env: e])
    # VecNormalize is not applied here.

    model = PPO('MlpPolicy', env, verbose=1, batch_size=2048, seed=seed)
    callback = SaveOnBestTrainingRewardCallback(check_freq=480, log_dir=log_dir)
    model.learn(total_timesteps=int(total_timesteps), callback=callback,
                log_interval=480)
    model.save(os.path.join(log_dir, 'ppo'))
Example #2
0
def train_td3(*, total_timesteps=100_000, seed=1, noise_sigma=0.1):
    """Train a TD3 agent on the training split of ``ENV_CONTINUE`` and save it.

    The environment is wrapped in a ``Monitor`` that logs episode statistics
    under ``model_save/``. ``SaveOnBestTrainingRewardCallback`` is checked every
    480 steps against that same directory. The final model is saved to
    ``model_save/td3``.

    Args:
        total_timesteps: Number of environment steps to train for.
        seed: Random seed forwarded to ``TD3``.
        noise_sigma: Standard deviation of the Gaussian exploration noise added
            to every action dimension. ``None`` disables action noise.
    """
    import os

    log_dir = "model_save/"
    # Create the log directory up front (as in the SB3 callback example):
    # Monitor writes its CSV inside it and the callback loads results from it.
    os.makedirs(log_dir, exist_ok=True)

    monitored_env = Monitor(ENV_CONTINUE(istest=False), log_dir)
    # Bind the monitored env as a default argument instead of closing over a
    # name that is rebound on this very line.
    env = DummyVecEnv([lambda e=monitored_env: e])
    # VecNormalize is not applied here.

    # TD3 learns a deterministic policy; exploration comes from noise added to
    # its actions, and that noise only takes effect if it is handed to TD3.
    if noise_sigma is None:
        action_noise = None
    else:
        n_actions = env.action_space.shape[-1]
        action_noise = NormalActionNoise(mean=np.zeros(n_actions),
                                         sigma=noise_sigma * np.ones(n_actions))

    model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=1,
                batch_size=2048, seed=seed)
    callback = SaveOnBestTrainingRewardCallback(check_freq=480,
                                                log_dir=log_dir)
    model.learn(total_timesteps=int(total_timesteps),
                callback=callback,
                log_interval=100)
    model.save(os.path.join(log_dir, 'td3'))
Example #3
0
def test_ppo(*, n_episodes=10, deterministic=False):
    """Evaluate the best saved PPO model on the test split of ``ENV_CONTINUE``.

    Loads ``model_save/best_model_ppo`` and plots the training results found in
    ``model_save/``. It then runs ``n_episodes`` episodes and prints, for each
    one, the ``profit`` and ``buy_hold`` attributes the environment holds when
    the episode ends.

    Args:
        n_episodes: Number of evaluation episodes to run.
        deterministic: Forwarded to ``model.predict``. The default ``False``
            keeps stochastic action selection.
    """
    model_path = "model_save/best_model_ppo"
    raw_env = ENV_CONTINUE(istest=True)
    # NOTE(review): this overwrites the env's ``render`` attribute with a bool;
    # presumably ENV_CONTINUE reads it as a flag — confirm.
    raw_env.render = True
    env = Monitor(raw_env, model_path)
    model = PPO.load(model_path)
    plot_results("model_save/")

    for i in range(n_episodes):
        state = env.reset()
        done = False
        while not done:
            # predict() returns (action, recurrent_state); only the action is used.
            action, _ = model.predict(state, deterministic=deterministic)
            state, _reward, done, _info = env.step(action)
        # Read metrics from the raw env rather than relying on the Monitor
        # wrapper forwarding attribute lookups to it.
        print('stock', i, ' total profit=', raw_env.profit,
              ' buy hold=', raw_env.buy_hold)