Example #1
def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a
    batch of one environment.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')

    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])

    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model,
                                   total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    assert actual1[-1].trunc_end
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    assert len(actual2) == len(expected2) + 1
    actual2 = actual2[:-1]
    _compare_rollout_batch(actual1, expected1)
    _compare_rollout_batch(actual2, expected2)
Example #2
def main():
    with tf.Session() as sess:
        print('Creating environment...')
        env = TFBatchedEnv(sess, Pong(), 1)
        env = BatchedFrameStack(env)

        print('Creating model...')
        model = CNN(sess,
                    gym_space_distribution(env.action_space),
                    gym_space_vectorizer(env.observation_space))

        print('Creating roller...')
        roller = TruncatedRoller(env, model, 1)

        print('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        if os.path.exists('params.pkl'):
            print('Loading parameters...')
            with open('params.pkl', 'rb') as in_file:
                params = pickle.load(in_file)
            for var, val in zip(tf.trainable_variables(), params):
                sess.run(tf.assign(var, val))
        else:
            print('Warning: parameter file does not exist!')

        print('Running agent...')
        viewer = SimpleImageViewer()
        while True:
            for obs in roller.rollouts()[0].step_observations:
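                # keep only the last 3 channels (the most recent frame in the stack)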
                viewer.imshow(obs[..., -3:])
Example #3
def test_output_consistency():
    """
    Test that outputs from stepping are consistent with
    the batched model outputs.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            env = batched_gym_env([DummyEnv] * 16, sync=True)
            model = ReActFF(sess,
                            *gym_spaces(env),
                            input_scale=1.0,
                            input_dtype=tf.float32,
                            base=lambda x: tf.layers.dense(x, 12),
                            actor=MatMul,
                            critic=lambda x: MatMul(x, 1))
            sess.run(tf.global_variables_initializer())
            roller = TruncatedRoller(env, model, 8)
            for _ in range(10):
                rollouts = roller.rollouts()
                actor_out_op, critic_out_op = model.batch_outputs()
                info = next(model.batches(rollouts))
                # evaluate the batched outputs on the feed produced for this batch
                actor_out, critic_out = sess.run((actor_out_op, critic_out_op),
                                                 feed_dict=info['feed_dict'])
                idxs = enumerate(
                    zip(info['rollout_idxs'], info['timestep_idxs']))
                for i, (rollout_idx, timestep_idx) in idxs:
                    outs = rollouts[rollout_idx].model_outs[timestep_idx]
                    assert np.allclose(actor_out[i], outs['action_params'][0])
                    assert np.allclose(critic_out[i], outs['values'][0])
Example #4
    def _test_truncation_case(self, stateful, state_tuple):
        """
        Test rollout truncation and continuation for a
        specific set of model parameters.
        """
        env_fn = lambda: SimpleEnv(7, (5, 3), 'uint8')
        env = env_fn()
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
        total_timesteps = sum([x.num_steps for x in expected])

        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model,
                                       total_timesteps // 2 + 1)
        actual1 = trunc_roller.rollouts()
        self.assertTrue(actual1[-1].trunc_end)
        actual2 = trunc_roller.rollouts()
        expected1, expected2 = _artificial_truncation(expected,
                                                      len(actual1) - 1,
                                                      actual1[-1].num_steps)
        self.assertEqual(len(actual2), len(expected2) + 1)
        actual2 = actual2[:-1]
        _compare_rollout_batch(self, actual1, expected1)
        _compare_rollout_batch(self, actual2, expected2)
Example #5
def main():
    args = arg_parser().parse_args()
    env = make_env(args)
    with tf.Session() as sess:
        model = make_model(args, sess, env)
        print('Initializing model variables...')
        sess.run(tf.global_variables_initializer())
        roller = TruncatedRoller(env, model, 128)
        total, good = 0, 0
        while True:
            # keep only rollouts from episodes that ended, not truncated ones
            completed = [r for r in roller.rollouts() if not r.trunc_end]
            sess.run(model.reptile.apply_updates)
            total += len(completed)
            good += len([x for x in completed if x.total_reward > 0])
            print('got %f (%d out of %d)' % (good / total, good, total))
Example #6
    def _test_basic_equivalence_case(self, stateful, state_tuple):
        """
        Test BasicRoller equivalence for a specific set of
        model settings.
        """
        env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
        env = env_fn()
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
        total_timesteps = sum([x.num_steps for x in expected])

        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
        actual = trunc_roller.rollouts()
        _compare_rollout_batch(self, actual, expected)
Example #7
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller
    for batches of one environment when the episodes end
    cleanly.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])

    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(actual, expected)
Example #8
    def _test_batch_equivalence_case(self, stateful, state_tuple):
        """
        Test that batched rollouts are consistent with
        unbatched rollouts, given the model parameters.
        """
        env_fns = [
            lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
        ]
        model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

        unbatched_rollouts = []
        for env_fn in env_fns:
            batched_env = batched_gym_env([env_fn], sync=True)
            trunc_roller = TruncatedRoller(batched_env, model, 17)
            for _ in range(3):
                unbatched_rollouts.extend(trunc_roller.rollouts())

        batched_rollouts = []
        batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            batched_rollouts.extend(trunc_roller.rollouts())

        _compare_rollout_batch(self,
                               unbatched_rollouts,
                               batched_rollouts,
                               ordered=False)
Example #9
def learn_pong():
    """Train an agent."""
    env = batched_gym_env([make_single_env] * NUM_WORKERS)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        with tf.Session() as sess:
            a2c = A2C(sess, agent, target_kl=TARGET_KL)
            roller = TruncatedRoller(env, agent, HORIZON)
            total_steps = 0
            rewards = []
            print("Training... Don't expect progress for ~400K steps.")
            while True:
                with agent.frozen():
                    rollouts = roller.rollouts()
                for rollout in rollouts:
                    total_steps += rollout.num_steps
                    if not rollout.trunc_end:
                        rewards.append(rollout.total_reward)
                agent.actor.extend(
                    a2c.policy_update(rollouts,
                                      POLICY_STEP,
                                      NUM_STEPS,
                                      min_leaf=MIN_LEAF,
                                      feature_frac=FEATURE_FRAC))
                agent.critic.extend(
                    a2c.value_update(rollouts,
                                     VALUE_STEP,
                                     NUM_STEPS,
                                     min_leaf=MIN_LEAF,
                                     feature_frac=FEATURE_FRAC))
                if rewards:
                    print(
                        '%d steps: mean=%f' %
                        (total_steps, sum(rewards[-10:]) / len(rewards[-10:])))
                else:
                    print('%d steps: no episodes complete yet' % total_steps)
    finally:
        env.close()
Example #10
def main():
    with tf.Session() as sess:
        print('Creating environment...')
        env = TFBatchedEnv(sess, Pong(), 8)
        env = BatchedFrameStack(env)

        print('Creating model...')
        model = CNN(sess, gym_space_distribution(env.action_space),
                    gym_space_vectorizer(env.observation_space))

        print('Creating roller...')
        roller = TruncatedRoller(env, model, 128)

        print('Creating PPO graph...')
        ppo = PPO(model)
        optimize = ppo.optimize(learning_rate=3e-4)

        print('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        print('Training agent...')
        for i in count():
            rollouts = roller.rollouts()
            for rollout in rollouts:
                if not rollout.trunc_end:
                    print('reward=%f steps=%d' %
                          (rollout.total_reward, rollout.total_steps))
            total_steps = sum(r.num_steps for r in rollouts)
            ppo.run_optimize(optimize,
                             rollouts,
                             batch_size=total_steps // 4,
                             num_iter=12,
                             log_fn=print)
            if i % 5 == 0:
                print('Saving...')
                parameters = sess.run(tf.trainable_variables())
                with open('params.pkl', 'wb+') as out_file:
                    pickle.dump(parameters, out_file)
Example #11
def test_batched_env_rollouts(benchmark):
    """
    Benchmark rollouts with a batched environment and a
    regular truncated roller.
    """
    env = batched_gym_env([lambda: gym.make('Pong-v0')] * 8)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        agent.actor = _testing_ensemble()
        agent.critic = _testing_ensemble(num_outs=1)
        roller = TruncatedRoller(env, agent, 128)
        with agent.frozen():
            benchmark(roller.rollouts)
    finally:
        env.close()
Example #12
def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.
    """
    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)

    expected_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        expected_rollouts.extend(trunc_roller.rollouts())
    for rollout in expected_rollouts:
        for model_out in rollout.model_outs:
            model_out['states'] = None

    actual_rollouts = []
    trunc_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
    for _ in range(3):
        actual_rollouts.extend(trunc_roller.rollouts())

    _compare_rollout_batch(actual_rollouts, expected_rollouts)
Example #13
def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.
    """
    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())

    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())

    _compare_rollout_batch(unbatched_rollouts, batched_rollouts, ordered=False)
Example #14
def mpi_ppo_loop(ppo,
                 env,
                 horizon=128,
                 lr=0.0003,
                 num_iters=16,
                 num_batches=4,
                 reward_scale=1.0,
                 save_path=None,
                 save_interval=5,
                 load_fn=None,
                 rollout_fn=None):
    """
    Run PPO forever on an environment.

    Args:
      ppo: an anyrl PPO instance.
      env: a batched environment.
      horizon: the number of timesteps per segment.
      lr: the Adam learning rate.
      num_iters: the number of training iterations.
      num_batches: the number of mini-batches per training
        epoch.
      reward_scale: a scale to bring rewards into a
        reasonable range.
      save_path: the variable state file.
      save_interval: outer loop iterations per save.
      load_fn: a function to call to load any extra TF
        variables before syncing and training.
      rollout_fn: a function that is called with every
        batch of rollouts before the rollouts are used.
    """
    from .mpi import is_mpi_root, mpi_log

    sess = ppo.model.session

    roller = TruncatedRoller(env, ppo.model, horizon)
    optimizer = MPIOptimizer(tf.train.AdamOptimizer(learning_rate=lr),
                             -ppo.objective,
                             var_list=ppo.variables)

    mpi_log('Initializing optimizer variables...')
    sess.run([v.initializer for v in optimizer.optimizer_vars])

    if save_path and is_mpi_root():
        load_vars(sess, save_path, var_list=ppo.variables)

    if load_fn is not None:
        load_fn()

    mpi_log('Syncing parameters...')
    optimizer.sync_from_root(sess)

    mpi_log('Training...')
    for i in itertools.count():
        mpi_ppo_round(ppo,
                      optimizer,
                      roller,
                      num_iters=num_iters,
                      num_batches=num_batches,
                      reward_scale=reward_scale,
                      rollout_fn=rollout_fn)

        if save_path and i % save_interval == 0 and is_mpi_root():
            save_vars(sess, save_path, var_list=ppo.variables)

        mpi_log('done iteration %d' % i)
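A minimal sketch of how mpi_ppo_loop might be invoked, reusing the Pong environment and model setup shown in Examples #2 and #10 and presumably launched under mpirun; wrapping the model with PPO(model) as in Example #10 is an assumption here, not part of the original example.

def main():
    with tf.Session() as sess:
        # environment and model construction mirrors Example #10
        env = TFBatchedEnv(sess, Pong(), 8)
        env = BatchedFrameStack(env)
        model = CNN(sess,
                    gym_space_distribution(env.action_space),
                    gym_space_vectorizer(env.observation_space))
        ppo = PPO(model)

        sess.run(tf.global_variables_initializer())

        # mpi_ppo_loop handles optimizer setup, parameter syncing across
        # MPI workers, rollouts, and the training rounds themselves
        mpi_ppo_loop(ppo, env, horizon=128)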