def test_vec_env(tmpdir):
    """Check VecNormalize clipping and its save/load round trip.

    Steps a normalized single-environment vec env with sampled actions
    until the first ``done`` flag, asserting on every step that the
    observation and reward magnitudes stay within the configured clip
    limits. The wrapper is then saved under ``tmpdir``, loaded back
    around the same base env, and compared with the original.
    """
    obs_limit = 0.5
    reward_limit = 5.0

    base_env = DummyVecEnv([make_env])
    wrapped_env = VecNormalize(
        base_env,
        norm_obs=True,
        norm_reward=True,
        clip_obs=obs_limit,
        clip_reward=reward_limit,
    )

    # Reset once up front; the returned observation is not inspected.
    wrapped_env.reset()
    dones = [False]
    while not dones[0]:
        sampled = [wrapped_env.action_space.sample()]
        observations, rewards, dones, _ = wrapped_env.step(sampled)
        assert np.max(np.abs(observations)) <= obs_limit
        assert np.max(np.abs(rewards)) <= reward_limit

    # Round trip: the reloaded wrapper is expected to match the original.
    save_path = str(tmpdir.join("vec_normalize"))
    wrapped_env.save(save_path)
    restored = VecNormalize.load(save_path, venv=base_env)
    check_vec_norm_equal(wrapped_env, restored)
Example #2
0
def _save_checkpoint(model, save_dir):
    """Save ``model`` under ``save_dir`` with a timestamped file name.

    Does nothing when ``model`` is ``None`` (i.e. training failed before a
    model object existed), so callers can use it from error handlers.
    """
    if model is None:
        return
    import os  # local import: keeps this block self-contained

    model.save(os.path.join(save_dir, time.strftime("%Y_%m_%d-%H_%M_%S")))


def train(num_timesteps, model_to_load, save_dir="D:/openAi/ppo2save/"):
    """Train a PPO2 agent and always try to save a timestamped checkpoint.

    A normalized vec env is built around ``dsgym`` and a new PPO2 model is
    created. When ``model_to_load`` is truthy, a fresh env is wrapped with
    the VecNormalize state loaded from the path obtained by replacing
    ``".zip"`` in ``model_to_load`` with ``"vec_normalize.pkl"``, the model
    is loaded from ``model_to_load`` and attached to that env, and the
    learning-rate function is replaced with
    ``linear_schedule_start_zero(lr)``.

    :param num_timesteps: passed to ``model.learn`` as ``total_timesteps``.
    :param model_to_load: path of a saved model to resume from, or a falsy
        value to train from scratch.
    :param save_dir: directory the checkpoint is written to; defaults to
        the location this function has always used.
    :raises SystemExit: with status 0 after a ``KeyboardInterrupt`` and
        status 1 after any other ``Exception``; in both cases a checkpoint
        is saved first if a model exists.
    """
    # Bind the name before the ``try`` so the handlers can tell whether a
    # model was ever created; otherwise a failure during setup would raise
    # UnboundLocalError from the handler and hide the real error.
    model = None
    try:
        env = DummyVecEnv([dsgym])
        env = VecNormalize(env)
        lr = 3e-4 * 0.75

        model = PPO2(policy=MlpPolicy,
                     env=env,
                     n_steps=2048,
                     nminibatches=32,
                     lam=0.95,
                     gamma=0.99,
                     noptepochs=10,
                     ent_coef=0.01,
                     learning_rate=linear_schedule(lr),
                     cliprange=0.2)
        if model_to_load:
            # Resume: rebuild the env around the saved normalization state
            # and swap in the stored model.
            env = DummyVecEnv([dsgym])
            env = VecNormalize.load(
                model_to_load.replace(".zip", "vec_normalize.pkl"), env)
            model = model.load(model_to_load)
            model.set_env(env)
            print("Loaded model from: ", model_to_load)
            # NOTE(review): set_learning_rate_func and
            # linear_schedule_start_zero are defined outside this block;
            # presumably they restart the schedule for resumed runs.
            model.set_learning_rate_func(linear_schedule_start_zero(lr))
        model.learn(total_timesteps=num_timesteps)
    except KeyboardInterrupt:
        print("Saving on keyinterrupt")
        _save_checkpoint(model, save_dir)
        # quit
        sys.exit()
    except Exception as error:
        # Report first so the diagnostic survives even if saving fails.
        print('An exception occurred: {}'.format(error))
        traceback.print_exception(*sys.exc_info())
        _save_checkpoint(model, save_dir)
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    _save_checkpoint(model, save_dir)