示例#1
0
        def test_nested_observations(self):
            """Test VecObsNormWrapper on a doubly wrapped single environment.

            The environment is wrapped twice in NestedVecObWrapper
            (presumably to produce nested observations -- confirm against the
            wrapper) and then in VecObsNormWrapper. The test checks that:

            * the step counter ``t`` starts at 0 and reaches 100 after 100
              steps, even when the environment is reset in between;
            * ``state_dict`` / ``load_state_dict`` round-trip ``t``;
            * stepping in eval mode leaves ``t`` untouched, while stepping in
              train mode advances it again.
            """
            nenv = 1

            def sample_actions():
                # One sampled action per environment, batched into an array.
                return np.array(
                    [env.action_space.sample() for _ in range(nenv)])

            logger.configure('./.test')
            try:
                env = make_env('CartPole-v1', nenv)
                env = NestedVecObWrapper(env)
                env = NestedVecObWrapper(env)
                env = VecObsNormWrapper(env, log_prob=1.)
                print(env.observation_space)
                env.reset()
                assert env.t == 0
                for _ in range(100):
                    _, _, done, _ = env.step(sample_actions())
                    # Only one environment is stepped here, so `done` is
                    # tested directly as a truth value.
                    if done:
                        env.reset()
                assert env.t == 100

                # The saved state must carry the counter and restore it.
                state = env.state_dict()
                assert state['t'] == env.t
                state['t'] = 0
                env.load_state_dict(state)
                assert env.t == 0

                # Eval mode: stepping must not advance the counter.
                env.eval()
                env.reset()
                for _ in range(3):
                    env.step(sample_actions())
                assert env.t == 0

                # Train mode: every step is counted again.
                env.train()
                for _ in range(3):
                    env.step(sample_actions())
                assert env.t == 3
                print(env.mean)
                print(env.std)
            finally:
                # Remove the log directory even when an assertion above fails,
                # so a failing run does not leave files behind for later tests.
                shutil.rmtree('./.test')
示例#2
0
        def test_vec(self):
            """Test VecObsNormWrapper on a vectorized environment.

            Runs ten CartPole environments through the wrapper and checks
            that:

            * the step counter ``t`` starts at 0 after ``reset``;
            * ``state_dict`` exposes ``t``, ``mean`` and ``std`` matching the
              wrapper's attributes, and ``load_state_dict`` restores ``t``;
            * stepping in eval mode leaves ``t`` untouched.
            """
            nenv = 10

            def sample_actions():
                # One sampled action per environment, batched into an array.
                return np.array(
                    [env.action_space.sample() for _ in range(nenv)])

            logger.configure('./.test')
            try:
                env = make_env('CartPole-v1', nenv)
                env = VecObsNormWrapper(env, log_prob=1.)
                print(env.observation_space)
                env.reset()
                assert env.t == 0
                for _ in range(5):
                    env.step(sample_actions())

                # The saved state must mirror the live statistics.
                state = env.state_dict()
                assert state['t'] == env.t
                assert np.allclose(state['mean'], env.mean)
                assert np.allclose(state['std'], env.std)
                state['t'] = 0
                env.load_state_dict(state)
                assert env.t == 0

                # Eval mode: stepping must not advance the counter.
                env.eval()
                env.reset()
                for _ in range(10):
                    env.step(sample_actions())
                assert env.t == 0
                env.train()
                print(env.mean)
                print(env.std)
            finally:
                # Remove the log directory even when an assertion above fails,
                # so a failing run does not leave files behind for later tests.
                shutil.rmtree('./.test')
示例#3
0
        def test_vec_logger(self):
            """Test VecEpisodeLogger on a subprocess vectorized environment.

            Four Pong environments (each wrapped in EpisodeInfo and seeded
            with its rank) are run through VecEpisodeLogger. The test checks
            that:

            * ``state_dict`` / ``load_state_dict`` round-trip the counter
              ``t``;
            * ten steps in eval mode leave ``t`` at 0 while ``lens`` reaches
              10;
            * ten further steps in train mode bring ``t`` to ``10 * nenv``
              and ``lens`` to 20.
            """
            nenv = 4

            def env_fn(rank=0):
                env = gym.make('PongNoFrameskip-v4')
                env.seed(rank)
                return EpisodeInfo(env)

            def _env(rank):
                # Bind `rank` now so each thunk builds its own seeded env.
                def _thunk():
                    return env_fn(rank=rank)

                return _thunk

            def sample_actions():
                # One sampled action per environment, batched into an array.
                return np.array(
                    [env.action_space.sample() for _ in range(nenv)])

            logger.configure('./.test')
            try:
                env = SubprocVecEnv([_env(i) for i in range(nenv)])
                env = VecEpisodeLogger(env)
                env.reset()
                for _ in range(5000):
                    env.step(sample_actions())

                # The saved state must carry the counter and restore it.
                state = env.state_dict()
                assert state['t'] == env.t
                state['t'] = 0
                env.load_state_dict(state)
                assert env.t == 0

                # Eval mode: episode lengths grow but `t` is not advanced.
                env.eval()
                env.reset()
                for _ in range(10):
                    env.step(sample_actions())
                assert env.t == 0
                assert np.allclose(env.lens, 10)

                # Train mode: `t` grows by one per environment per step.
                env.train()
                for _ in range(10):
                    env.step(sample_actions())
                assert env.t == 10 * nenv
                assert np.allclose(env.lens, 20)
                logger.flush()
            finally:
                # Remove the log directory even when an assertion above fails,
                # so a failing run does not leave files behind for later tests.
                shutil.rmtree('./.test')
示例#4
0
def train(logdir,
          algorithm,
          seed=0,
          eval=False,
          eval_period=None,
          save_period=None,
          maxt=None,
          maxseconds=None,
          hardware_poll_period=1):
    """Basic training loop.

    Configures logging under ``<logdir>/tb``, seeds the rng, builds the
    algorithm, writes the operative gin config to ``<logdir>/config.gin`` and
    then repeatedly calls ``alg.step()`` until a stopping condition is met.
    The model is saved when the loop ends normally or is interrupted with
    Ctrl-C. The hardware logger, the logger and the algorithm are always
    released, including when an unexpected exception propagates.

    Args:
        logdir (str):
            The base directory for the training run.
        algorithm (callable):
            The algorithm class (or factory) to use for training. A new
            instance is constructed with ``algorithm(logdir=logdir)``.
        seed (int):
            The initial seed of this experiment.
        eval (bool):
            Whether or not to evaluate the model throughout training.
            Evaluation only runs when 'eval_period' is also set. (The name
            shadows the builtin but is kept for backward compatibility.)
        eval_period (int):
            The period with which the model is evaluated.
        save_period (int):
            The period with which the model is saved. If unset, the model is
            only saved when training ends.
        maxt (int):
            The maximum number of timesteps to train the model. A falsy value
            (None or 0) means no limit. If the loaded timestep already
            exceeds 'maxt', the function returns without training or saving.
        maxseconds (float):
            The maximum amount of time to train the model. A falsy value
            means no limit.
        hardware_poll_period (float):
            The period in seconds at which cpu/gpu stats are polled and logged.
            Use 'None' (or a non-positive value) to disable logging.
    """

    logger.configure(os.path.join(logdir, 'tb'))
    rng.seed(seed)
    alg = algorithm(logdir=logdir)
    hardware_logger = None
    try:
        config = gin.operative_config_str()
        logger.log("=================== CONFIG ===================")
        logger.log(config)
        with open(os.path.join(logdir, 'config.gin'), 'w') as f:
            f.write(config)
        time_start = time.monotonic()
        t = alg.load()
        if t == 0:
            # Fresh run: also record the config as text. Trailing double
            # spaces and escaped '#' presumably keep it readable when the
            # text is rendered as Markdown -- confirm against the logger.
            cstr = config.replace('\n', '  \n')
            cstr = cstr.replace('#', '\\#')
            logger.add_text('config', cstr, 0, time.time())
        if maxt and t > maxt:
            # Already trained past the limit: nothing to do. Cleanup still
            # happens in the `finally` clause below.
            return

        # Align the bookkeeping to the most recent period boundary at or
        # before `t`. Both names are always bound so the loop below can never
        # hit an unbound local when a period is left unset.
        last_save = (t // save_period) * save_period if save_period else t
        last_eval = (t // eval_period) * eval_period if eval_period else t

        if hardware_poll_period is not None and hardware_poll_period > 0:
            hardware_logger = HardwareLogger(delay=hardware_poll_period)

        try:
            while True:
                if maxt and t >= maxt:
                    break
                if maxseconds and time.monotonic() - time_start >= maxseconds:
                    break
                t = alg.step()
                if save_period and (t - last_save) >= save_period:
                    alg.save()
                    last_save = t
                # Require a usable eval_period: `eval=True` on its own used to
                # raise UnboundLocalError here on the first step.
                if eval and eval_period and (t - last_eval) >= eval_period:
                    alg.evaluate()
                    last_eval = t
        except KeyboardInterrupt:
            logger.log("Caught Ctrl-C. Saving model and exiting...")
        # Reached on normal termination and on Ctrl-C, but not when another
        # exception propagates out of the loop.
        alg.save()
    finally:
        # Release resources on every exit path, including the early return
        # above and unexpected exceptions.
        if hardware_logger:
            hardware_logger.stop()
        logger.flush()
        logger.close()
        alg.close()