def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a
    batch of one environment.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    assert actual1[-1].trunc_end
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    assert len(actual2) == len(expected2) + 1
    actual2 = actual2[:-1]
    _compare_rollout_batch(actual1, expected1)
    _compare_rollout_batch(actual2, expected2)

def main():
    with tf.Session() as sess:
        print('Creating environment...')
        env = TFBatchedEnv(sess, Pong(), 1)
        env = BatchedFrameStack(env)
        print('Creating model...')
        model = CNN(sess,
                    gym_space_distribution(env.action_space),
                    gym_space_vectorizer(env.observation_space))
        print('Creating roller...')
        roller = TruncatedRoller(env, model, 1)
        print('Initializing variables...')
        sess.run(tf.global_variables_initializer())
        if os.path.exists('params.pkl'):
            print('Loading parameters...')
            with open('params.pkl', 'rb') as in_file:
                params = pickle.load(in_file)
            for var, val in zip(tf.trainable_variables(), params):
                sess.run(tf.assign(var, val))
        else:
            print('Warning: parameter file does not exist!')
        print('Running agent...')
        viewer = SimpleImageViewer()
        while True:
            for obs in roller.rollouts()[0].step_observations:
                # Show only the most recent frame of the stacked observation.
                viewer.imshow(obs[..., -3:])

def test_output_consistency():
    """
    Test that outputs from stepping are consistent with
    the batched model outputs.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            env = batched_gym_env([DummyEnv] * 16, sync=True)
            model = ReActFF(sess,
                            *gym_spaces(env),
                            input_scale=1.0,
                            input_dtype=tf.float32,
                            base=lambda x: tf.layers.dense(x, 12),
                            actor=MatMul,
                            critic=lambda x: MatMul(x, 1))
            sess.run(tf.global_variables_initializer())
            roller = TruncatedRoller(env, model, 8)
            for _ in range(10):
                rollouts = roller.rollouts()
                info = next(model.batches(rollouts))
                actor_out, critic_out = sess.run(model.batch_outputs(),
                                                 feed_dict=info['feed_dict'])
                idxs = enumerate(zip(info['rollout_idxs'], info['timestep_idxs']))
                for i, (rollout_idx, timestep_idx) in idxs:
                    outs = rollouts[rollout_idx].model_outs[timestep_idx]
                    assert np.allclose(actor_out[i], outs['action_params'][0])
                    assert np.allclose(critic_out[i], outs['values'][0])

def _test_truncation_case(self, stateful, state_tuple):
    """
    Test rollout truncation and continuation for a
    specific set of model parameters.
    """
    env_fn = lambda: SimpleEnv(7, (5, 3), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    self.assertTrue(actual1[-1].trunc_end)
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    self.assertEqual(len(actual2), len(expected2) + 1)
    actual2 = actual2[:-1]
    _compare_rollout_batch(self, actual1, expected1)
    _compare_rollout_batch(self, actual2, expected2)

def main():
    args = arg_parser().parse_args()
    env = make_env(args)
    with tf.Session() as sess:
        model = make_model(args, sess, env)
        print('Initializing model variables...')
        sess.run(tf.global_variables_initializer())
        roller = TruncatedRoller(env, model, 128)
        total, good = 0, 0
        while True:
            # Only count rollouts that ended on a real episode boundary.
            completed = [r for r in roller.rollouts() if not r.trunc_end]
            sess.run(model.reptile.apply_updates)
            total += len(completed)
            good += len([x for x in completed if x.total_reward > 0])
            if total:  # avoid division by zero before any episode completes
                print('got %f (%d out of %d)' % (good / total, good, total))

def _test_basic_equivalence_case(self, stateful, state_tuple):
    """
    Test BasicRoller equivalence for a specific set of
    model settings.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(self, actual, expected)

def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller
    for batches of one environment when the episodes end
    cleanly.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(actual, expected)

def _test_batch_equivalence_case(self, stateful, state_tuple):
    """
    Test that doing things in batches is consistent, given
    the model parameters.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)
    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())
    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(self, unbatched_rollouts, batched_rollouts,
                           ordered=False)

def learn_pong():
    """Train an agent."""
    env = batched_gym_env([make_single_env] * NUM_WORKERS)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        with tf.Session() as sess:
            a2c = A2C(sess, agent, target_kl=TARGET_KL)
            roller = TruncatedRoller(env, agent, HORIZON)
            total_steps = 0
            rewards = []
            print("Training... Don't expect progress for ~400K steps.")
            while True:
                with agent.frozen():
                    rollouts = roller.rollouts()
                for rollout in rollouts:
                    total_steps += rollout.num_steps
                    if not rollout.trunc_end:
                        rewards.append(rollout.total_reward)
                agent.actor.extend(
                    a2c.policy_update(rollouts, POLICY_STEP, NUM_STEPS,
                                      min_leaf=MIN_LEAF,
                                      feature_frac=FEATURE_FRAC))
                agent.critic.extend(
                    a2c.value_update(rollouts, VALUE_STEP, NUM_STEPS,
                                     min_leaf=MIN_LEAF,
                                     feature_frac=FEATURE_FRAC))
                if rewards:
                    print('%d steps: mean=%f' %
                          (total_steps, sum(rewards[-10:]) / len(rewards[-10:])))
                else:
                    print('%d steps: no episodes complete yet' % total_steps)
    finally:
        env.close()

def main():
    with tf.Session() as sess:
        print('Creating environment...')
        env = TFBatchedEnv(sess, Pong(), 8)
        env = BatchedFrameStack(env)
        print('Creating model...')
        model = CNN(sess,
                    gym_space_distribution(env.action_space),
                    gym_space_vectorizer(env.observation_space))
        print('Creating roller...')
        roller = TruncatedRoller(env, model, 128)
        print('Creating PPO graph...')
        ppo = PPO(model)
        optimize = ppo.optimize(learning_rate=3e-4)
        print('Initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('Training agent...')
        for i in count():
            rollouts = roller.rollouts()
            for rollout in rollouts:
                if not rollout.trunc_end:
                    print('reward=%f steps=%d' %
                          (rollout.total_reward, rollout.total_steps))
            total_steps = sum(r.num_steps for r in rollouts)
            ppo.run_optimize(optimize, rollouts,
                             batch_size=total_steps // 4,
                             num_iter=12,
                             log_fn=print)
            if i % 5 == 0:
                print('Saving...')
                parameters = sess.run(tf.trainable_variables())
                with open('params.pkl', 'wb+') as out_file:
                    pickle.dump(parameters, out_file)

def test_batched_env_rollouts(benchmark):
    """
    Benchmark rollouts with a batched environment and a
    regular truncated roller.
    """
    env = batched_gym_env([lambda: gym.make('Pong-v0')] * 8)
    try:
        # The observation space needs a vectorizer, not a distribution.
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        agent.actor = _testing_ensemble()
        agent.critic = _testing_ensemble(num_outs=1)
        roller = TruncatedRoller(env, agent, 128)
        with agent.frozen():
            benchmark(roller.rollouts)
    finally:
        env.close()

def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)
    expected_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        expected_rollouts.extend(trunc_roller.rollouts())
    for rollout in expected_rollouts:
        for model_out in rollout.model_outs:
            model_out['states'] = None
    actual_rollouts = []
    trunc_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
    for _ in range(3):
        actual_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(actual_rollouts, expected_rollouts)

def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)
    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())
    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(unbatched_rollouts, batched_rollouts,
                           ordered=False)

def mpi_ppo_loop(ppo, env,
                 horizon=128,
                 lr=0.0003,
                 num_iters=16,
                 num_batches=4,
                 reward_scale=1.0,
                 save_path=None,
                 save_interval=5,
                 load_fn=None,
                 rollout_fn=None):
    """
    Run PPO forever on an environment.

    Args:
      ppo: an anyrl PPO instance.
      env: a batched environment.
      horizon: the number of timesteps per segment.
      lr: the Adam learning rate.
      num_iters: the number of training iterations per PPO round.
      num_batches: the number of mini-batches per training epoch.
      reward_scale: a scale to bring rewards into a reasonable range.
      save_path: the variable state file.
      save_interval: outer loop iterations per save.
      load_fn: a function to call to load any extra TF variables
        before syncing and training.
      rollout_fn: a function that is called with every batch of
        rollouts before the rollouts are used.
    """
    from .mpi import is_mpi_root, mpi_log
    sess = ppo.model.session
    roller = TruncatedRoller(env, ppo.model, horizon)
    optimizer = MPIOptimizer(tf.train.AdamOptimizer(learning_rate=lr),
                             -ppo.objective,
                             var_list=ppo.variables)
    mpi_log('Initializing optimizer variables...')
    sess.run([v.initializer for v in optimizer.optimizer_vars])
    if save_path and is_mpi_root():
        load_vars(sess, save_path, var_list=ppo.variables)
    if load_fn is not None:
        load_fn()
    mpi_log('Syncing parameters...')
    optimizer.sync_from_root(sess)
    mpi_log('Training...')
    for i in itertools.count():
        mpi_ppo_round(ppo, optimizer, roller,
                      num_iters=num_iters,
                      num_batches=num_batches,
                      reward_scale=reward_scale,
                      rollout_fn=rollout_fn)
        # Only the root rank touches the checkpoint file.
        if save_path and i % save_interval == 0 and is_mpi_root():
            save_vars(sess, save_path, var_list=ppo.variables)
        mpi_log('done iteration %d' % i)
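

# A minimal usage sketch for mpi_ppo_loop, not part of the API above.
# The environment choice, the MLP layer sizes, and the save path are
# illustrative assumptions; any anyrl model/PPO pair on a batched env
# should work the same way. Every MPI rank runs the same script, e.g.
# `mpirun -n 4 python train.py`, and mpi_ppo_loop syncs parameters
# from the root rank before training.
def example_mpi_ppo_main():
    env = batched_gym_env([lambda: gym.make('CartPole-v0')] * 8)
    try:
        with tf.Session() as sess:
            model = MLP(sess,
                        gym_space_distribution(env.action_space),
                        gym_space_vectorizer(env.observation_space),
                        layer_sizes=[32])
            ppo = PPO(model)
            # Model variables must be initialized before the loop;
            # the loop itself only initializes the optimizer's slots.
            sess.run(tf.global_variables_initializer())
            mpi_ppo_loop(ppo, env, horizon=128, save_path='ppo_params.pkl')
    finally:
        env.close()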