def test_nested_observations(self):
    """Test VecObsNormWrapper on doubly nested observations.

    Wraps a single CartPole environment in two NestedVecObWrapper layers
    and a VecObsNormWrapper, then checks that:

    * ``env.t`` is 0 after reset and 100 after 100 steps in train mode,
    * ``state_dict`` / ``load_state_dict`` round-trip the ``t`` counter,
    * ``env.t`` does not advance while in eval mode,
    * ``env.t`` advances again after switching back to train mode.
    """
    logger.configure('./.test')
    try:
        env = make_env('CartPole-v1', 1)
        env = NestedVecObWrapper(env)
        env = NestedVecObWrapper(env)
        env = VecObsNormWrapper(env, log_prob=1.)
        print(env.observation_space)

        env.reset()
        assert env.t == 0
        for _ in range(100):
            _, _, done, _ = env.step(
                np.array([env.action_space.sample() for _ in range(1)]))
            if done:
                env.reset()
        assert env.t == 100

        # The step counter must survive a save/load round trip.
        state = env.state_dict()
        assert state['t'] == env.t
        state['t'] = 0
        env.load_state_dict(state)
        assert env.t == 0

        # Stepping in eval mode must leave the counter untouched.
        env.eval()
        env.reset()
        for _ in range(3):
            env.step(np.array([env.action_space.sample()]))
        assert env.t == 0

        # Back in train mode the counter advances once per step.
        env.train()
        for _ in range(3):
            env.step(np.array([env.action_space.sample()]))
        assert env.t == 3

        print(env.mean)
        print(env.std)
    finally:
        # Always remove the log directory, even when an assertion fails,
        # so a failing run does not leak state into later tests.
        # ignore_errors keeps a missing directory from masking the real
        # test failure.
        shutil.rmtree('./.test', ignore_errors=True)
def test_vec(self):
    """Test VecObsNormWrapper on a vectorized environment.

    Builds 10 CartPole environments wrapped in VecObsNormWrapper and
    checks that:

    * ``env.t`` is 0 after reset,
    * ``state_dict`` exposes ``t``, ``mean`` and ``std`` matching the
      wrapper's attributes,
    * ``load_state_dict`` restores an overwritten ``t``,
    * ``env.t`` does not advance while in eval mode.
    """
    logger.configure('./.test')
    try:
        nenv = 10
        env = make_env('CartPole-v1', nenv)
        env = VecObsNormWrapper(env, log_prob=1.)
        print(env.observation_space)

        env.reset()
        assert env.t == 0
        for _ in range(5):
            env.step(
                np.array([env.action_space.sample() for _ in range(nenv)]))

        # The saved state must mirror the live wrapper attributes.
        state = env.state_dict()
        assert state['t'] == env.t
        assert np.allclose(state['mean'], env.mean)
        assert np.allclose(state['std'], env.std)
        state['t'] = 0
        env.load_state_dict(state)
        assert env.t == 0

        # Stepping in eval mode must leave the counter untouched.
        env.eval()
        env.reset()
        for _ in range(10):
            env.step(
                np.array([env.action_space.sample() for _ in range(nenv)]))
        assert env.t == 0

        env.train()
        print(env.mean)
        print(env.std)
    finally:
        # Always remove the log directory, even when an assertion fails,
        # so a failing run does not leak state into later tests.
        # ignore_errors keeps a missing directory from masking the real
        # test failure.
        shutil.rmtree('./.test', ignore_errors=True)
def test_vec_logger(self):
    """Test VecEpisodeLogger on top of a SubprocVecEnv.

    Runs 4 seeded Pong environments (each wrapped in EpisodeInfo) for
    5000 steps and then checks that:

    * ``state_dict`` / ``load_state_dict`` round-trip the ``t`` counter,
    * in eval mode ``env.t`` stays at 0 while ``env.lens`` still counts
      the steps taken since the reset,
    * in train mode ``env.t`` advances by ``nenv`` per step and
      ``env.lens`` keeps accumulating.
    """
    logger.configure('./.test')
    try:
        def env_fn(rank=0):
            env = gym.make('PongNoFrameskip-v4')
            env.seed(rank)
            return EpisodeInfo(env)

        def _env(rank):
            # Bind `rank` per environment so each thunk seeds differently.
            def _thunk():
                return env_fn(rank=rank)
            return _thunk

        nenv = 4
        # NOTE(review): the SubprocVecEnv is never closed in this test;
        # confirm whether env.close() should be called during cleanup.
        env = SubprocVecEnv([_env(i) for i in range(nenv)])
        env = VecEpisodeLogger(env)
        env.reset()
        for _ in range(5000):
            env.step(
                np.array([env.action_space.sample() for _ in range(nenv)]))

        # The step counter must survive a save/load round trip.
        state = env.state_dict()
        assert state['t'] == env.t
        state['t'] = 0
        env.load_state_dict(state)
        assert env.t == 0

        # Eval mode: the global counter is frozen, episode lengths are not.
        env.eval()
        env.reset()
        for _ in range(10):
            env.step(
                np.array([env.action_space.sample() for _ in range(nenv)]))
        assert env.t == 0
        assert np.allclose(env.lens, 10)

        # Train mode: every step counts once per environment.
        env.train()
        for _ in range(10):
            env.step(
                np.array([env.action_space.sample() for _ in range(nenv)]))
        assert env.t == 10 * nenv
        assert np.allclose(env.lens, 20)
        logger.flush()
    finally:
        # Always remove the log directory, even when an assertion fails,
        # so a failing run does not leak state into later tests.
        # ignore_errors keeps a missing directory from masking the real
        # test failure.
        shutil.rmtree('./.test', ignore_errors=True)
def train(logdir,
          algorithm,
          seed=0,
          eval=False,
          eval_period=None,
          save_period=None,
          maxt=None,
          maxseconds=None,
          hardware_poll_period=1):
    """Basic training loop.

    Configures logging under ``logdir``, seeds the RNG, builds the
    algorithm, writes the operative gin config to ``config.gin`` and then
    repeatedly calls ``alg.step()`` until a stopping condition is met or
    the user presses Ctrl-C. The model is saved when the loop ends, and the
    loggers and the algorithm are always closed once the loop has started,
    even if a step raises.

    Args:
        logdir (str): The base directory for the training run.
        algorithm (callable): The algorithm class to use for training.
            A new instance is constructed with ``algorithm(logdir=logdir)``.
        seed (int): The initial seed of this experiment.
        eval (bool): Whether or not to evaluate the model throughout
            training. Evaluation only runs when ``eval_period`` is also set.
        eval_period (int): The period with which the model is evaluated.
        save_period (int): The period with which the model is saved.
        maxt (int): The maximum number of timesteps to train the model.
        maxseconds (float): The maximum amount of time to train the model.
        hardware_poll_period (float): The period in seconds at which cpu/gpu
            stats are polled and logged. Use 'None' to disable logging.
    """
    logger.configure(os.path.join(logdir, 'tb'))
    rng.seed(seed)
    alg = algorithm(logdir=logdir)

    config = gin.operative_config_str()
    logger.log("=================== CONFIG ===================")
    logger.log(config)
    with open(os.path.join(logdir, 'config.gin'), 'w') as f:
        f.write(config)

    time_start = time.monotonic()
    t = alg.load()
    if t == 0:
        # Only record the config as text at the very start of a run.
        # Presumably the replacements adapt it for markdown rendering.
        cstr = config.replace('\n', ' \n')
        cstr = cstr.replace('#', '\\#')
        logger.add_text('config', cstr, 0, time.time())

    # NOTE(review): this early return leaves `alg` and the logger open;
    # confirm whether they should be closed here as well.
    if maxt and t > maxt:
        return

    # Align the "last" markers to the most recent period boundary so that a
    # resumed run keeps the same save/eval schedule.
    if save_period:
        last_save = (t // save_period) * save_period

    # Evaluation needs both the flag and a usable period. Previously
    # eval=True without eval_period left `last_eval` undefined and crashed
    # after the first step.
    do_eval = bool(eval and eval_period)
    if eval and not eval_period:
        logger.log("WARNING: eval=True but eval_period is not set; "
                   "evaluation is disabled.")
    if do_eval:
        last_eval = (t // eval_period) * eval_period

    if hardware_poll_period is not None and hardware_poll_period > 0:
        hardware_logger = HardwareLogger(delay=hardware_poll_period)
    else:
        hardware_logger = None

    try:
        try:
            while True:
                if maxt and t >= maxt:
                    break
                if maxseconds and time.monotonic() - time_start >= maxseconds:
                    break
                t = alg.step()
                if save_period and (t - last_save) >= save_period:
                    alg.save()
                    last_save = t
                if do_eval and (t - last_eval) >= eval_period:
                    alg.evaluate()
                    last_eval = t
        except KeyboardInterrupt:
            logger.log("Caught Ctrl-C. Saving model and exiting...")
        # Reached on normal termination and on Ctrl-C; other exceptions
        # skip the save and go straight to cleanup.
        alg.save()
    finally:
        # Release resources even when a step/save/evaluate call raises.
        if hardware_logger:
            hardware_logger.stop()
        logger.flush()
        logger.close()
        alg.close()