Example #1
    def _test_batch_equivalence_case(self, stateful, state_tuple):
        """
        Test that doing things in batches is consistent,
        given the model parameters.
        """
        env_fns = [
            lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
        ]
        model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

        unbatched_rollouts = []
        for env_fn in env_fns:
            batched_env = batched_gym_env([env_fn], sync=True)
            trunc_roller = TruncatedRoller(batched_env, model, 17)
            for _ in range(3):
                unbatched_rollouts.extend(trunc_roller.rollouts())

        batched_rollouts = []
        batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            batched_rollouts.extend(trunc_roller.rollouts())

        _compare_rollout_batch(self,
                               unbatched_rollouts,
                               batched_rollouts,
                               ordered=False)
Example #2
def test_dummy_equiv_dtype(dtype):
    """
    Test that batched_gym_env() gives something
    equivalent to a synchronous environment.
    """
    def make_fn(seed):
        """
        Get an environment constructor with a seed.
        """
        return lambda: SimpleEnv(seed, SHAPE, dtype)

    fns = [make_fn(i) for i in range(SUB_BATCH_SIZE * NUM_SUB_BATCHES)]
    real = batched_gym_env(fns, num_sub_batches=NUM_SUB_BATCHES)
    dummy = batched_gym_env(fns, num_sub_batches=NUM_SUB_BATCHES, sync=True)
    try:
        _assert_resets_equal(dummy, real)
        np.random.seed(1337)
        for _ in range(NUM_STEPS):
            joint_shape = (SUB_BATCH_SIZE, ) + SHAPE
            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
                               dtype=dtype)
            _assert_steps_equal(actions, dummy, real)
    finally:
        dummy.close()
        real.close()
Example #3
def test_batched_stack(concat):
    """
    Test that BatchedFrameStack is equivalent to a regular
    batched FrameStackEnv.
    """
    envs = [
        lambda idx=i: SimpleEnv(idx + 2, (3, 2, 5), 'float32')
        for i in range(6)
    ]
    env1 = BatchedFrameStack(batched_gym_env(envs,
                                             num_sub_batches=3,
                                             sync=True),
                             concat=concat)
    env2 = batched_gym_env(
        [lambda env=e: FrameStackEnv(env(), concat=concat) for e in envs],
        num_sub_batches=3,
        sync=True)
    for j in range(50):
        for i in range(3):
            if j == 0 or (j + i) % 17 == 0:
                env1.reset_start(sub_batch=i)
                env2.reset_start(sub_batch=i)
                obs1 = env1.reset_wait(sub_batch=i)
                obs2 = env2.reset_wait(sub_batch=i)
                assert np.allclose(obs1, obs2)
            actions = [env1.action_space.sample() for _ in range(2)]
            env1.step_start(actions, sub_batch=i)
            env2.step_start(actions, sub_batch=i)
            obs1, rews1, dones1, _ = env1.step_wait(sub_batch=i)
            obs2, rews2, dones2, _ = env2.step_wait(sub_batch=i)
            assert np.allclose(obs1, obs2)
            assert np.array(rews1 == rews2).all()
            assert np.array(dones1 == dones2).all()
Example #4
def test_async_creation_exit():
    """
    Test that an exception is forwarded when the
    environment constructor exits.
    """
    try:
        batched_gym_env([lambda: sys.exit(1)] * 4)
    except RuntimeError:
        return
    pytest.fail('should have gotten exception')
Example #5
def test_async_creation_exception():
    """
    Test that an exception is forwarded when the
    environment constructor fails.
    """
    try:

        def raiser():
            raise ValueError('hello world')

        batched_gym_env([raiser] * 4)
    except RuntimeError:
        return
    pytest.fail('should have gotten exception')
Example #6
def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a
    batch of one environment.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')

    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])

    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model,
                                   total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    assert actual1[-1].trunc_end
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    assert len(actual2) == len(expected2) + 1
    actual2 = actual2[:-1]
    _compare_rollout_batch(actual1, expected1)
    _compare_rollout_batch(actual2, expected2)
Example #7
    def _test_truncation_case(self, stateful, state_tuple):
        """
        Test rollout truncation and continuation for a
        specific set of model parameters.
        """
        env_fn = lambda: SimpleEnv(7, (5, 3), 'uint8')
        env = env_fn()
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
        total_timesteps = sum([x.num_steps for x in expected])

        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model,
                                       total_timesteps // 2 + 1)
        actual1 = trunc_roller.rollouts()
        self.assertTrue(actual1[-1].trunc_end)
        actual2 = trunc_roller.rollouts()
        expected1, expected2 = _artificial_truncation(expected,
                                                      len(actual1) - 1,
                                                      actual1[-1].num_steps)
        self.assertEqual(len(actual2), len(expected2) + 1)
        actual2 = actual2[:-1]
        _compare_rollout_batch(self, actual1, expected1)
        _compare_rollout_batch(self, actual2, expected2)
Example #8
def test_output_consistency():
    """
    Test that outputs from stepping are consistent with
    the batched model outputs.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            env = batched_gym_env([DummyEnv] * 16, sync=True)
            model = ReActFF(sess,
                            *gym_spaces(env),
                            input_scale=1.0,
                            input_dtype=tf.float32,
                            base=lambda x: tf.layers.dense(x, 12),
                            actor=MatMul,
                            critic=lambda x: MatMul(x, 1))
            sess.run(tf.global_variables_initializer())
            roller = TruncatedRoller(env, model, 8)
            for _ in range(10):
                rollouts = roller.rollouts()
                batch_outs = model.batch_outputs()
                info = next(model.batches(rollouts))
                actor_out, critic_out = sess.run(batch_outs,
                                                 feed_dict=info['feed_dict'])
                idxs = enumerate(
                    zip(info['rollout_idxs'], info['timestep_idxs']))
                for i, (rollout_idx, timestep_idx) in idxs:
                    outs = rollouts[rollout_idx].model_outs[timestep_idx]
                    assert np.allclose(actor_out[i], outs['action_params'][0])
                    assert np.allclose(critic_out[i], outs['values'][0])
Example #9
    def _test_batch_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Test BasicRoller equivalence when using a batch of
        environments.
        """
        env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
        model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

        batched_env = batched_gym_env([env_fn] * 21,
                                      num_sub_batches=7,
                                      sync=True)
        ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
        actual = ep_roller.rollouts()

        total_steps = sum([r.num_steps for r in actual])
        self.assertTrue(len(actual) >= ep_roller.min_episodes)
        self.assertTrue(total_steps >= ep_roller.min_steps)

        if 'min_steps' not in roller_kwargs:
            num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
            self.assertTrue(len(actual) == num_eps)

        basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
        expected = basic_roller.rollouts()

        _compare_rollout_batch(self, actual, expected)
Example #10
def test_ep_batches(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a
    BasicRoller when run on a batch of envs.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')

    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)

    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **limits)
    actual = ep_roller.rollouts()

    total_steps = sum([r.num_steps for r in actual])
    assert len(actual) >= ep_roller.min_episodes
    assert total_steps >= ep_roller.min_steps

    if 'min_steps' not in limits:
        num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
        assert len(actual) == num_eps

    basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
    expected = basic_roller.rollouts()

    _compare_rollout_batch(actual, expected)
Example #11
def test_mixed_batch():
    """
    Test a batch with a bunch of different
    environments.
    """
    env_fns = [
        lambda s=seed: SimpleEnv(s, (1, 2, 3), 'float32')
        for seed in [5, 8, 1, 9, 3, 2]
    ]
    make_agent = lambda: SimpleModel((1, 2, 3), stateful=True)
    for num_sub in [1, 2, 3]:
        batched_player = BatchedPlayer(
            batched_gym_env(env_fns, num_sub_batches=num_sub), make_agent(), 3)
        expected_eps = []
        for player in [
                BasicPlayer(env_fn(), make_agent(), 3) for env_fn in env_fns
        ]:
            transes = [t for _ in range(50) for t in player.play()]
            expected_eps.extend(_separate_episodes(transes))
        actual_transes = [t for _ in range(50) for t in batched_player.play()]
        actual_eps = _separate_episodes(actual_transes)
        assert len(expected_eps) == len(actual_eps)
        for episode in expected_eps:
            found = False
            for i, actual in enumerate(actual_eps):
                if _episodes_equivalent(episode, actual):
                    del actual_eps[i]
                    found = True
                    break
            assert found
Example #12
def main():
    """
    Entry-point for the program.
    """
    args = _parse_args()
    env = batched_gym_env([partial(make_single_env, args.game)] * args.workers)

    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(env, num_images=4, concat=False)

    with tf.Session() as sess:

        def make_net(name):
            return NatureQNetwork(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  name,
                                  dueling=True)

        dqn = DQN(make_net('online'), make_net('target'))
        player = BatchedPlayer(env,
                               EpsGreedyQNetwork(dqn.online_net, args.epsilon))
        optimize = dqn.optimize(learning_rate=args.lr)

        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            if len(reward_hist) == REWARD_HISTORY:
                print('%d steps: mean=%f' %
                      (total_steps, sum(reward_hist) / len(reward_hist)))
                reward_hist.clear()

        dqn.train(num_steps=int(1e7),
                  player=player,
                  replay_buffer=UniformReplayBuffer(args.buffer_size),
                  optimize_op=optimize,
                  target_interval=args.target_interval,
                  batch_size=args.batch_size,
                  min_buffer_size=args.min_buffer_size,
                  handle_ep=_handle_ep)

    env.close()
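
The script above relies on a _parse_args() helper that is not shown. A minimal argparse sketch covering exactly the attributes referenced above (flag names mirror the attribute names; the default values are illustrative assumptions, not the original script's settings):

import argparse


def _parse_args():
    """Hypothetical argument parser for the flags used in Example #12."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--epsilon', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--buffer-size', type=int, default=500000)
    parser.add_argument('--target-interval', type=int, default=8192)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--min-buffer-size', type=int, default=20000)
    return parser.parse_args()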
Example #13
def test_single_batch():
    """
    Test BatchedPlayer when the batch size is 1.
    """
    make_env = lambda: SimpleEnv(9, (1, 2, 3), 'float32')
    make_agent = lambda: SimpleModel((1, 2, 3), stateful=True)
    basic_player = BasicPlayer(make_env(), make_agent(), 3)
    batched_player = BatchedPlayer(batched_gym_env([make_env]), make_agent(),
                                   3)
    for _ in range(50):
        transes1 = basic_player.play()
        transes2 = batched_player.play()
        assert len(transes1) == len(transes2)
        for trans1, trans2 in zip(transes1, transes2):
            assert _transitions_equal(trans1, trans2)
Example #14
def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.
    """
    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())

    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())

    _compare_rollout_batch(unbatched_rollouts, batched_rollouts, ordered=False)
Example #15
def make_env(args):
    maze_data = ("A.......\n" +
                 "wwwwwww.\n" +
                 "wwx.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "........")
    maze = parse_2d_maze(maze_data)

    def _make_env():
        return TimeLimit(HorizonEnv(maze, sparse_rew=True, horizon=2),
                         max_episode_steps=args.max_timesteps)

    return batched_gym_env([_make_env] * args.num_envs, sync=True)
Example #16
def test_batched_env_rollouts(benchmark):
    """
    Benchmark rollouts with a batched environment and a
    regular truncated roller.
    """
    env = batched_gym_env([lambda: gym.make('Pong-v0')] * 8)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_distribution(env.observation_space))
        agent.actor = _testing_ensemble()
        agent.critic = _testing_ensemble(num_outs=1)
        roller = TruncatedRoller(env, agent, 128)
        with agent.frozen():
            benchmark(roller.rollouts)
    finally:
        env.close()
Example #17
def main():
    """Run DQN until the environment throws an exception."""
    env_fns, env_names = create_envs()
    env = BatchedFrameStack(batched_gym_env(env_fns),
                            num_images=4,
                            concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)  # Use ADAM
        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew, env_rewards):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            print('%d episodes, %d steps: mean of last 100 episodes=%f' %
                  (len(reward_hist), total_steps,
                   sum(reward_hist[-100:]) / len(reward_hist[-100:])))

        dqn.train(
            # Make sure an exception arrives before we stop.
            num_steps=2000000000,
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
            handle_ep=_handle_ep,
            num_envs=len(env_fns),
            save_interval=10,
        )
Example #18
def test_ep_basic_equivalence(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a
    BasicRoller when run on a single environment.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, **limits)
    expected = basic_roller.rollouts()

    batched_env = batched_gym_env([env_fn], sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **limits)
    actual = ep_roller.rollouts()
    _compare_rollout_batch(actual, expected)
Example #19
    def _test_basic_equivalence_case(self, stateful, state_tuple):
        """
        Test BasicRoller equivalence for a specific set of
        model settings.
        """
        env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
        env = env_fn()
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, min_episodes=5)
        expected = basic_roller.rollouts()
        total_timesteps = sum([x.num_steps for x in expected])

        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
        actual = trunc_roller.rollouts()
        _compare_rollout_batch(self, actual, expected)
Example #20
    def _test_basic_equivalence_case(self, stateful, state_tuple,
                                     **roller_kwargs):
        """
        Test BasicRoller equivalence for a single env in a
        specific case.
        """
        env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
        env = env_fn()
        model = SimpleModel(env.action_space.low.shape,
                            stateful=stateful,
                            state_tuple=state_tuple)
        basic_roller = BasicRoller(env, model, **roller_kwargs)
        expected = basic_roller.rollouts()

        batched_env = batched_gym_env([env_fn], sync=True)
        ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
        actual = ep_roller.rollouts()
        _compare_rollout_batch(self, actual, expected)
Example #21
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller
    for batches of one environment when the episodes end
    cleanly.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])

    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(actual, expected)
Example #22
    def test_multiple_batches(self):
        """
        Make sure calling rollouts multiple times works.
        """
        env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
        env = env_fn()
        try:
            model = SimpleModel(env.action_space.low.shape)
        finally:
            env.close()
        batched_env = batched_gym_env([env_fn], sync=True)
        try:
            ep_roller = EpisodeRoller(batched_env,
                                      model,
                                      min_episodes=5,
                                      min_steps=7)
            first = ep_roller.rollouts()
            for _ in range(3):
                _compare_rollout_batch(self, first, ep_roller.rollouts())
        finally:
            batched_env.close()
Example #23
def learn_pong():
    """Train an agent."""
    env = batched_gym_env([make_single_env] * NUM_WORKERS)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        with tf.Session() as sess:
            a2c = A2C(sess, agent, target_kl=TARGET_KL)
            roller = TruncatedRoller(env, agent, HORIZON)
            total_steps = 0
            rewards = []
            print("Training... Don't expect progress for ~400K steps.")
            while True:
                with agent.frozen():
                    rollouts = roller.rollouts()
                for rollout in rollouts:
                    total_steps += rollout.num_steps
                    if not rollout.trunc_end:
                        rewards.append(rollout.total_reward)
                agent.actor.extend(
                    a2c.policy_update(rollouts,
                                      POLICY_STEP,
                                      NUM_STEPS,
                                      min_leaf=MIN_LEAF,
                                      feature_frac=FEATURE_FRAC))
                agent.critic.extend(
                    a2c.value_update(rollouts,
                                     VALUE_STEP,
                                     NUM_STEPS,
                                     min_leaf=MIN_LEAF,
                                     feature_frac=FEATURE_FRAC))
                if rewards:
                    print(
                        '%d steps: mean=%f' %
                        (total_steps, sum(rewards[-10:]) / len(rewards[-10:])))
                else:
                    print('%d steps: no episodes complete yet' % total_steps)
    finally:
        env.close()
Example #24
def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.
    """
    env_fns = [
        lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)
    ]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)

    expected_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        expected_rollouts.extend(trunc_roller.rollouts())
    for rollout in expected_rollouts:
        for model_out in rollout.model_outs:
            model_out['states'] = None

    actual_rollouts = []
    trunc_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
    for _ in range(3):
        actual_rollouts.extend(trunc_roller.rollouts())

    _compare_rollout_batch(actual_rollouts, expected_rollouts)
Example #25
# Get our envs before we import tensorflow, in case they need their own tf instance.
from multienv import getEnvFns
from anyrl.envs import batched_gym_env
env_fns = getEnvFns(bk2dir='data/record/')
env = batched_gym_env(env_fns)

import sys
from rainbow import train

train(env, output_dir='/tmp/', num_steps=100000)
Example #26
def getBatchedEnv(bk2dir=None):
    env_fns = getEnvFns(bk2dir=bk2dir)
    return batched_gym_env(env_fns)
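
A hypothetical way to consume this helper, mirroring the stacking setup from Examples #12 and #17 (the BatchedFrameStack import and arguments are taken from those examples; the bk2dir path is illustrative):

from anyrl.envs import BatchedFrameStack

# Stack the last four observations per environment; concat=False keeps the
# frames as separate images, which the earlier examples note is more memory
# efficient than concatenating them.
env = BatchedFrameStack(getBatchedEnv(bk2dir='data/record/'),
                        num_images=4,
                        concat=False)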
Example #27
    def __init__(self, FLAGS):
        self.FLAGS = FLAGS
        self.is_training_ph = tf.placeholder(tf.bool, [])

        with rl_util.Timer('building_envs', self.FLAGS):
            fns = [make_fn(i, FLAGS) for i in range(FLAGS['num_envs'])]
            if self.FLAGS['num_envs'] == 1:
                self.env = fns[0]()
            else:
                self.env = batched_gym_env(fns)

        self.action_dist = rl_util.convert_to_dict_dist(
            self.env.action_space.spaces)
        self.obs_vectorizer = rl_util.convert_to_dict_dist(
            self.env.observation_space.spaces)

        # TF INIT STUFF
        # TODO: make these passable as cmd line params
        self.global_itr = tf.get_variable('global_itr',
                                          initializer=tf.constant(
                                              1, dtype=tf.int32),
                                          trainable=False)
        self.inc_global = tf.assign_add(self.global_itr,
                                        tf.constant(1, dtype=tf.int32))

        # sarsd phs
        in_batch_shape = (None, ) + self.obs_vectorizer.out_shape
        self.sarsd_phs = {}
        self.sarsd_phs['s'] = tf.placeholder(tf.float32,
                                             shape=in_batch_shape,
                                             name='s_ph')
        self.sarsd_phs['a'] = tf.placeholder(
            tf.float32, (None, ) + self.action_dist.out_shape,
            name='a_ph')  # actions that were taken
        self.sarsd_phs['s_next'] = tf.placeholder(tf.float32,
                                                  shape=in_batch_shape,
                                                  name='s_next_ph')
        self.sarsd_phs['r'] = tf.placeholder(tf.float32,
                                             shape=(None, ),
                                             name='r_ph')
        self.sarsd_phs['d'] = tf.placeholder(tf.float32,
                                             shape=(None, ),
                                             name='d_ph')

        self.embed_phs = {
            key: tf.placeholder(tf.float32,
                                shape=(None, self.FLAGS['embed_shape']),
                                name='{}_ph'.format(key))
            for key in self.FLAGS['embeds']
        }
        for key in ['a', 'r', 'd']:
            self.embed_phs[key] = self.sarsd_phs[key]

        # Pre-compute these transforms so we don't have to do it all the time
        # sarsd vals
        self.sarsd_vals = rl_util.sarsd_to_vals(self.sarsd_phs,
                                                self.obs_vectorizer,
                                                self.FLAGS)

        # ALGO SETUP
        Encoder = VAE if 'vae' in self.FLAGS['cnn_gn'] else DYN
        EName = 'VAE' if 'vae' in self.FLAGS['cnn_gn'] else 'DYN'

        if self.FLAGS['goal_dyn'] != '':
            self.goal_model = DYN(self.sarsd_vals,
                                  self.sarsd_phs,
                                  self.action_dist,
                                  self.obs_vectorizer,
                                  FLAGS,
                                  conv='cnn' in self.FLAGS['goal_dyn'],
                                  name='GoalDYN',
                                  compute_grad=False).model
        else:
            self.goal_model = None

        if self.FLAGS['aac']:
            self.value_encoder = Encoder(self.sarsd_vals,
                                         self.sarsd_phs,
                                         self.action_dist,
                                         self.obs_vectorizer,
                                         FLAGS,
                                         conv=False,
                                         name='Value' + EName)
            if self.FLAGS['value_goal']:
                self.goal_model = self.value_encoder.model
        else:
            self.value_encoder = None

        self.encoder = Encoder(self.sarsd_vals,
                               self.sarsd_phs,
                               self.action_dist,
                               self.obs_vectorizer,
                               FLAGS,
                               conv='cnn' in self.FLAGS['cnn_gn'],
                               goal_model=self.goal_model,
                               is_training_ph=self.is_training_ph)

        Agent = {'scripted': Scripted, 'sac': SAC}[self.FLAGS['agent']]
        self.agent = Agent(
            sas_vals=self.sarsd_vals,
            sas_phs=self.sarsd_phs,
            embed_phs=self.embed_phs,
            action_dist=self.action_dist,
            obs_vectorizer=self.obs_vectorizer,
            FLAGS=FLAGS,
            dyn=self.encoder if self.FLAGS['share_dyn'] else None,
            value_dyn=self.value_encoder if self.FLAGS['aac'] else None,
            is_training_ph=self.is_training_ph)

        if self.FLAGS['agent'] == 'sac':
            self.scripted_agent = Scripted(
                sas_vals=self.sarsd_vals,
                sas_phs=self.sarsd_phs,
                embed_phs=self.embed_phs,
                action_dist=self.action_dist,
                obs_vectorizer=self.obs_vectorizer,
                FLAGS=FLAGS,
                dyn=self.encoder if self.FLAGS['share_dyn'] else None)

        self.eval_vals = {}
        if self.FLAGS['grad_summaries']:
            self.eval_vals['summary'] = tf.summary.merge(
                self.encoder.grad_summaries)
        else:
            self.eval_vals['summary'] = self.encoder.eval_vals.pop('summary')
            #self.eval_vals['summary'] = tf.no_op()
        self.eval_vals.update(self.encoder.eval_vals)
        if self.FLAGS['aac']:
            self.eval_vals.update(
                prefix_vals('value', self.value_encoder.eval_vals))

        self.sess = get_session()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10,
                                    keep_checkpoint_every_n_hours=0.5)
        self.train_writer = tf.summary.FileWriter(
            os.path.join(self.FLAGS['log_path'], 'tb'), self.sess.graph)

        # Load pre-trained variables from a path maybe.
        if self.FLAGS['load_path'] != '':

            def get_vars_to_load(path):
                load_names = [
                    name
                    for name, _ in tf.contrib.framework.list_variables(path)
                ]
                vars = [
                    var for var in tf.global_variables()
                    if var.name[:-2] in load_names and 'Adam' not in var.name
                ]
                return vars

            path = self.FLAGS['load_path']
            loader = tf.train.Saver(var_list=get_vars_to_load(path))
            loader.restore(self.sess, path)
            print()
            print('Loaded trained variables from', path)
            print()

        # SUMMARY STUFF
        self.pred_plot_ph = tf.placeholder(tf.uint8)
        self.plot_summ = tf.summary.image('mdn_contour',
                                          self.pred_plot_ph,
                                          max_outputs=3)
        self.value_plot_summ = tf.summary.image('value_mdn_contour',
                                                self.pred_plot_ph,
                                                max_outputs=3)

        self.gif_paths_ph = tf.placeholder(tf.string,
                                           shape=(None, ),
                                           name='gif_path_ph')
        self.gif_summ = rl_util.gif_summary('rollout',
                                            self.gif_paths_ph,
                                            max_outputs=3)

        #self.sess.graph.finalize()
        # Make separate process for gif because it takes a long time
        self.gif_rollout_queue = Queue(maxsize=3)
        self.gif_path_queue = Queue(maxsize=3)
        # TODO: what is passed in a python process?
        if self.FLAGS['run_rl_optim']:
            self.gif_proc = Process(
                target=gif_plotter,
                daemon=True,
                args=(self.obs_vectorizer, self.action_dist, self.FLAGS,
                      self.gif_rollout_queue, self.gif_path_queue))
            self.gif_proc.start()

        self.mean_summ = defaultdict(lambda: None)
        self.mean_summ_phs = defaultdict(lambda: None)

        numvars()

        if self.FLAGS['threading']:
            # multi-thread for a moderate speed-up
            self.rollout_queue = queue.Queue(maxsize=3)
            self.rollout_thread = Thread(target=self.rollout_maker,
                                         daemon=True)
Example #28
def main():
    """
    Entry-point for the program.
    """
    args = _parse_args()

    # batched_gym_env creates one copy of the game per worker and steps them together as a batch.
    # make_single_env wraps the game with GrayscaleEnv and then DownsampleEnv:
    #   GrayscaleEnv converts RGB observations to grayscale.
    #   DownsampleEnv shrinks observations by a given factor (e.g. 2x smaller).
    env = batched_gym_env([partial(make_single_env, args.game)] * args.workers)
    env_test = make_single_env(args.game)
    print('Observation space:', env_test.observation_space)
    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(env, num_images=4, concat=False)

    with tf.Session() as sess:

        def make_net():
            return rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200)

        dqn = DQN(*make_net())
        player = BatchedPlayer(env,
                               EpsGreedyQNetwork(dqn.online_net, args.epsilon))
        optimize = dqn.optimize(learning_rate=args.lr)

        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            if len(reward_hist) == REWARD_HISTORY:
                print('%d steps: mean=%f' %
                      (total_steps, sum(reward_hist) / len(reward_hist)))
                reward_hist.clear()

        dqn.train(num_steps=int(1e7),
                  player=player,
                  replay_buffer=UniformReplayBuffer(args.buffer_size),
                  optimize_op=optimize,
                  target_interval=args.target_interval,
                  batch_size=args.batch_size,
                  min_buffer_size=args.min_buffer_size,
                  handle_ep=_handle_ep)

    env.close()
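
make_single_env is not shown here; a minimal sketch that is consistent with the comments above, assuming the GrayscaleEnv and DownsampleEnv wrappers they mention and a 2x downsampling rate chosen purely for illustration:

import gym
from anyrl.envs.wrappers import DownsampleEnv, GrayscaleEnv


def make_single_env(game):
    """Hypothetical preprocessing: grayscale the frames, then downsample them 2x."""
    return DownsampleEnv(GrayscaleEnv(gym.make(game)), 2)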