コード例 #1
0
ファイル: test_players.py プロジェクト: decoderkurt/anyrl-py
def test_mixed_batch():
    """
    Test a batch with a bunch of different
    environments.

    A BatchedPlayer over all the environments, split into 1, 2 and 3
    sub-batches, must yield the same multiset of episodes as running one
    BasicPlayer per environment.
    """
    # NOTE(review): every seed is currently 3, so the environments are not
    # actually different; the varied seeds survive only in the comment.
    seeds = [3, 3, 3, 3, 3, 3]  #[5, 8, 1, 9, 3, 2]]
    env_fns = [lambda s=seed: SimpleEnv(s, (1, 2, 3), 'float32')
               for seed in seeds]

    def make_agent():
        return SimpleModel((1, 2, 3), stateful=True)

    for sub_batches in (1, 2, 3):
        batched_player = BatchedPlayer(
            batched_gym_env(env_fns, num_sub_batches=sub_batches),
            make_agent(), 3)

        # Reference episodes: every environment on its own player.
        basic_players = [BasicPlayer(build_env(), make_agent(), 3)
                         for build_env in env_fns]
        expected_eps = []
        for basic_player in basic_players:
            transitions = [t for _ in range(50) for t in basic_player.play()]
            expected_eps.extend(_separate_episodes(transitions))

        actual_eps = _separate_episodes(
            [t for _ in range(50) for t in batched_player.play()])
        assert len(expected_eps) == len(actual_eps)

        # Pair each expected episode with a distinct equivalent actual one.
        for episode in expected_eps:
            for idx, candidate in enumerate(actual_eps):
                if _episodes_equivalent(episode, candidate):
                    del actual_eps[idx]
                    break
            else:
                assert False
コード例 #2
0
ファイル: test_players.py プロジェクト: decoderkurt/anyrl-py
def test_single_batch():
    """
    Check that a BatchedPlayer over exactly one environment produces the
    same transitions, step for step, as a BasicPlayer on an identical env.
    """
    def make_env():
        return SimpleEnv(9, (1, 2, 3), 'float32')

    def make_agent():
        return SimpleModel((1, 2, 3), stateful=True)

    basic_player = BasicPlayer(make_env(), make_agent(), 3)
    batched_player = BatchedPlayer(batched_gym_env([make_env]),
                                   make_agent(), 3)
    for _ in range(50):
        basic_transitions = basic_player.play()
        batched_transitions = batched_player.play()
        assert len(basic_transitions) == len(batched_transitions)
        assert all(_transitions_equal(left, right)
                   for left, right in zip(basic_transitions,
                                          batched_transitions))
コード例 #3
0
def main():
    """
    Train a Rainbow DQN agent, checkpointing before and after training.

    Training runs for ``num_steps`` timesteps, or until the environment
    raises, or until Ctrl-C. On Ctrl-C a final checkpoint is still written
    to ``base_path + "final"``. The environment is closed on exit.
    """
    base_path = "results/rainbow/6/"
    # Large on purpose: make sure an exception arrives before we stop.
    num_steps = 2_000_000
    env = make_env(stack=False, scale_rew=False, render=None, monitor=base_path + "train_monitor",
                   episodic_life=True)
    # NOTE(review): AllowBacktracking is not applied here; the original
    # author assumed the env itself already allows backtracking — confirm.
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Let this process use at most 80% of the GPU's memory.
    config.gpu_options.per_process_gpu_memory_fraction = 0.8

    try:
        with tf.Session(config=config) as sess:
            dqn = DQN(*rainbow_models(sess,
                                      env.action_space.n, gym_space_vectorizer(env.observation_space),
                                      min_val=-200, max_val=200))
            player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
            optimize = dqn.optimize(learning_rate=1e-4)
            saver = tf.train.Saver(name="rainbow")
            sess.run(tf.global_variables_initializer())
            # Snapshot of the freshly initialized weights (step 0).
            saver.save(sess, base_path + "training", global_step=0)
            try:
                # NOTE(review): `handle_ep` is not defined in this function;
                # presumably a module-level episode callback — confirm.
                dqn.train(num_steps=num_steps,
                          player=player,
                          replay_buffer=PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1),
                          optimize_op=optimize,
                          train_interval=1,
                          target_interval=8192,
                          batch_size=64,
                          min_buffer_size=20000,
                          handle_ep=handle_ep)
            except KeyboardInterrupt:
                print("keyboard interrupt")
            print("finishing")
            # NOTE(review): tagged with num_steps even when training was
            # interrupted earlier; the checkpoint name is kept unchanged.
            saver.save(sess, base_path + "final", global_step=num_steps)
    finally:
        # Release the environment (and its monitor) however training ends.
        env.close()
コード例 #4
0
def main():
    """
    Train Rainbow DQN on SonicTheHedgehog-Genesis, GreenHillZone.Act1.

    The step budget is deliberately huge: the run is expected to end when
    the environment raises, not when the budget is used up.
    """
    game_env = make(game='SonicTheHedgehog-Genesis', state='GreenHillZone.Act1')
    env = BatchedFrameStack(BatchedGymEnv([[game_env]]), num_images=4, concat=False)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    with tf.Session(config=config) as sess:
        action_count = env.action_space.n
        vectorizer = gym_space_vectorizer(env.observation_space)
        models = rainbow_models(sess, action_count, vectorizer,
                                min_val=-200, max_val=200)
        dqn = DQN(*models)
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        train_kwargs = dict(
            num_steps=2000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
        )
        dqn.train(**train_kwargs)
コード例 #5
0
def main():
    """
    Train Rainbow DQN with a StochasticMaxStochasticDeltaDeletionPRB
    replay buffer, constructed as (500000, 0.5, 0.4, epsilon=0.1).

    The step budget is intentionally large so that an exception from the
    environment ends the run first.
    """
    base_env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[base_env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        n_actions = env.action_space.n
        obs_vectorizer = gym_space_vectorizer(env.observation_space)
        dqn = DQN(*rainbow_models(sess, n_actions, obs_vectorizer,
                                  min_val=-200, max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        buffer = StochasticMaxStochasticDeltaDeletionPRB(500000, 0.5, 0.4,
                                                         epsilon=0.1)
        dqn.train(num_steps=2000000,  # Make sure an exception arrives before we stop.
                  player=player,
                  replay_buffer=buffer,
                  optimize_op=optimize,
                  train_interval=1,
                  target_interval=8192,
                  batch_size=32,
                  min_buffer_size=20000)
コード例 #6
0
def main():
    """
    Train Rainbow DQN with min_val/max_val of -421/421.

    The target network is updated every 64 timesteps, and training waits
    for 25000 samples in the prioritized replay buffer.
    """
    backtracking_env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[backtracking_env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        online_net, target_net = rainbow_models(
            sess,
            env.action_space.n,
            gym_space_vectorizer(env.observation_space),
            min_val=-421,
            max_val=421)
        dqn = DQN(online_net, target_net)
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())
        buffer = PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1)
        dqn.train(num_steps=2000000,
                  player=player,
                  replay_buffer=buffer,
                  optimize_op=optimize,
                  train_interval=1,
                  target_interval=64,
                  batch_size=32,
                  min_buffer_size=25000)
コード例 #7
0
ファイル: rainbow_local.py プロジェクト: lbertge/retro-noob
def main():
    """
    Train Rainbow DQN on Green Hill Zone Act 1, then export the weights.

    After training, the weights are written twice. The first copy goes
    through ``utils.save_state`` to ``<save_path>_tf_saver``. The second is
    the list of trainable-variable values, dumped with joblib to
    ``<save_path>_joblib``.
    """
    env = make(game='SonicTheHedgehog-Genesis', state='GreenHillZone.Act1')
    env = AllowBacktracking(make_local_env(env, stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())
        # NOTE(review): `num_steps` is not defined in this function; it is
        # presumably a module-level constant — confirm.
        dqn.train(num_steps=num_steps, # Make sure an exception arrives before we stop.
                  player=player,
                  replay_buffer=PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1),
                  optimize_op=optimize,
                  train_interval=1,
                  target_interval=8192,
                  batch_size=32,
                  min_buffer_size=20000)

        # Entering `tf.variable_scope('model')` does not filter what
        # `tf.trainable_variables()` returns, so the former scope wrapper was
        # a no-op. Collect the full list once and reuse it.
        params = tf.trainable_variables()
        print(params)
        save_path = '/home/noob/retro-noob/rainbow/params/params'
        utils.save_state(save_path + '_tf_saver')

        ps = sess.run(params)
        joblib.dump(ps, save_path + '_joblib')
コード例 #8
0
def main():
    """
    Train Rainbow DQN (backtracking allowed) for up to 1,000,000 timesteps.

    Adam uses learning rate 6.25e-5 and epsilon 1.5e-4.
    """
    env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

        # Create a TF op that optimizes the objective.
        #   learning_rate: the Adam learning rate.
        #   epsilon: the Adam epsilon.
        optimize = dqn.optimize(learning_rate=6.25e-5, epsilon=1.5e-4)

        sess.run(tf.global_variables_initializer())

        # Run the automated training loop. dqn.train arguments:
        #   num_steps: the number of timesteps to run.
        #   player: the Player for gathering experience.
        #   replay_buffer: the ReplayBuffer for experience.
        #   optimize_op: a TF op to optimize the model.
        #   train_interval: timesteps per training step.
        #   target_interval: timesteps between target network updates.
        #   batch_size: the size of experience mini-batches.
        #   min_buffer_size: minimum replay buffer size before training
        #     is performed.
        #   tf_schedules: TFSchedules updated with the number of steps taken.
        #   handle_ep: called with information about every completed episode.
        #   timeout: if set, seconds after which the training loop exits.
        dqn.train(
            num_steps=1000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000)
コード例 #9
0
ファイル: anyrl_builders.py プロジェクト: antkve/rl-zoo
 def finish(self, sess, dqn):
     """
     Assemble the training components for ``dqn`` around ``self.env``.

     Args:
       sess: unused here (presumably part of a shared builder interface).
       dqn: the DQN whose online network drives the player and whose
         optimizer op is created.

     Returns:
       A dict with "player", "optimize_op" and "replay_buffer" entries.
     """
     # BatchedPlayer steps a batched environment, so give it the wrapped
     # env. Previously this wrapper was built and then ignored, and the raw
     # self.env was passed to BatchedPlayer instead.
     # NOTE(review): assumes self.env is a plain, unbatched gym env — confirm.
     env = BatchedGymEnv([[self.env]])
     return {
         "player": NStepPlayer(BatchedPlayer(env, dqn.online_net), 3),
         "optimize_op": dqn.optimize(learning_rate=0.002),
         "replay_buffer": PrioritizedReplayBuffer(20000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.2),
     }
コード例 #10
0
def main():
    """
    Replay a saved Rainbow model and plot its per-episode rewards.

    Restores ``results/rainbow/7/final-4000000``, wraps it in a
    ``LoadedNetwork`` and plays 40 episodes through an n-step batched
    player. On each step the latest ``total_reward`` is fed to a
    ``RewardPlotter``, and each episode's plot is saved to
    ``<save_dir>/e<episode_index>.pdf``. Ctrl-C stops early. The env and the
    current matplotlib figure are closed however the loop ends.
    """
    # "results/rainbow/2/videos/6"
    save_dir = "results/rainbow/7/val_monitor/2"
    env = make_env(stack=False,
                   scale_rew=False,
                   render=60,
                   monitor=save_dir,
                   timelimit=False,
                   episodic_life=False,
                   single_life=True,
                   video=lambda id: True)
    # env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()

    with tf.Session(config=config) as sess:
        saver = tf.train.import_meta_graph(
            "results/rainbow/7/final-4000000.meta", clear_devices=True)
        # saver.restore(sess, tf.train.latest_checkpoint('results/rainbow/2'))
        saver.restore(sess, 'results/rainbow/7/final-4000000')
        # Rebuild the online network from the restored graph.
        model = LoadedNetwork(sess,
                              gym_space_vectorizer(env.observation_space))
        player = NStepPlayer(BatchedPlayer(env, model), 3)

        # NOTE(review): tf.device only places ops created inside its scope;
        # the graph was imported above, so this may have no effect — confirm.
        with tf.device("/cpu"):
            try:
                for episode_index in tqdm(range(40), unit="episode"):
                    axes = make_axes()
                    plotter = RewardPlotter(axes,
                                            save_period=40,
                                            render_period=600,
                                            max_entries=600)
                    for i in count():
                        trajectories = player.play()
                        end_of_episode = False
                        # NOTE(review): stays None when play() returns no
                        # transitions; RewardPlotter.update must accept None.
                        current_total_reward = None
                        for trajectory in trajectories:
                            current_total_reward = trajectory["total_reward"]
                            if trajectory["is_last"]:
                                end_of_episode = True
                        plotter.update(current_total_reward, step=i)
                        if end_of_episode:
                            plotter.render()
                            plotter.save_file("{}/e{}.pdf".format(
                                save_dir, episode_index))
                            plotter.close()
                            break
            except KeyboardInterrupt:
                pass  # Stop early; cleanup happens in `finally`.
            finally:
                # Previously done only on Ctrl-C, which left the env open
                # whenever all 40 episodes completed normally.
                env.close()
                plt.close()
コード例 #11
0
def main():
    """
    Train Rainbow DQN and log episode statistics.

    ``restore_path='./pretrained_model'`` and ``save_interval=None`` are
    passed through to ``dqn.train``. After every 10th completed episode,
    the episode count, cumulative steps and mean reward of the last 100
    episodes are printed.
    """
    env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

        # Other exploration schedules. Re-enabling a schedule also means
        # passing `tf_schedules=[eps_decay_sched]` to dqn.train below.
        #eps_decay_sched = LinearTFSchedule(50000, 1.0, 0.01)
        #player = NStepPlayer(BatchedPlayer(env, EpsGreedyQNetwork(dqn.online_net, 0.1)), 3)
        #player = NStepPlayer(BatchedPlayer(env, EpsGreedyQNetwork(dqn.online_net, TFScheduleValue(sess, eps_decay_sched))), 3)
        #player = NStepPlayer(BatchedPlayer(env, SonicEpsGreedyQNetwork(dqn.online_net, TFScheduleValue(sess, eps_decay_sched))), 3)

        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew, env_rewards):
            """Record one finished episode; print stats every 10 episodes."""
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            # Gate on the episode count, which is what the message reports.
            # The old `total_steps % 10` check fired only when the step
            # total happened to be a multiple of 10.
            if len(reward_hist) % 10 == 0:
                recent = reward_hist[-100:]
                print('%d episodes, %d steps: mean of last 100 episodes=%f' %
                      (len(reward_hist), total_steps,
                       sum(recent) / len(recent)))

        # `tf_schedules` is omitted: `eps_decay_sched` exists only in the
        # commented-out code above, so referencing it raised NameError, and
        # the active player does not use a schedule.
        dqn.train(
            num_steps=2000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
            handle_ep=_handle_ep,
            restore_path='./pretrained_model',
            save_interval=None,
        )
コード例 #12
0
def main():
    """
    Pre-fill a replay buffer with expert demonstrations, train Rainbow DQN
    briefly on MineRLNavigateDense-v0, then evaluate the player.
    """
    env_name = 'MineRLNavigateDense-v0'
    base_env = [SimpleNavigateEnvWrapper(get_env(env_name)) for _ in range(1)]
    env = BatchedFrameStack(BatchedGymEnv([base_env]),
                            num_images=4,
                            concat=True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        online, target = mine_rainbow_online_target(mine_cnn,
                                                    sess,
                                                    env.action_space.n,
                                                    gym_space_vectorizer(
                                                        env.observation_space),
                                                    min_val=-200,
                                                    max_val=200)
        dqn = DQN(online, target)
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        buffer_capacity = 5000
        # dqn.train performs no training step until the buffer holds
        # min_buffer_size samples. A buffer of capacity buffer_capacity can
        # never reach a larger threshold, so clamp it. The old value, 20000,
        # meant no optimization ever ran.
        min_buffer_size = min(20000, buffer_capacity)

        replay_buffer = PrioritizedReplayBuffer(buffer_capacity,
                                                0.5,
                                                0.4,
                                                epsilon=0.1)

        # Seed the buffer with n-step transitions from expert demonstrations.
        demo_iter = non_bugged_data_arr(env_name, num_trajs=100)
        expert_player = NStepPlayer(ImitationPlayer(demo_iter, 200), 3)

        for traj in expert_player.play():
            replay_buffer.add_sample(traj, init_weight=1)

        print('starting training')
        dqn.train(num_steps=200,
                  player=player,
                  replay_buffer=replay_buffer,
                  optimize_op=optimize,
                  train_interval=1,
                  target_interval=8192,
                  batch_size=32,
                  min_buffer_size=min_buffer_size)

        print('starting eval')
        # NOTE(review): this sets a private attribute on the NStepPlayer
        # wrapper. If the state to reset lives on the inner BatchedPlayer,
        # this assignment may have no effect — confirm.
        player._cur_states = None
        score = evaluate(player)
        print(score)
コード例 #13
0
def main():
    """
    Entry-point for the program.

    Trains a dueling NatureQNetwork DQN with epsilon-greedy exploration on
    ``args.game``, using ``args.workers`` environments in one batch. Every
    ``REWARD_HISTORY`` completed episodes, the mean reward is printed and
    the history is reset.
    """
    args = _parse_args()
    env_fns = [partial(make_single_env, args.game)] * args.workers

    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(batched_gym_env(env_fns), num_images=4, concat=False)

    with tf.Session() as sess:

        def make_net(name):
            return NatureQNetwork(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  name,
                                  dueling=True)

        dqn = DQN(make_net('online'), make_net('target'))
        exploring_net = EpsGreedyQNetwork(dqn.online_net, args.epsilon)
        player = BatchedPlayer(env, exploring_net)
        optimize = dqn.optimize(learning_rate=args.lr)

        sess.run(tf.global_variables_initializer())

        recent_rewards = []
        step_count = 0

        def _handle_ep(steps, rew):
            nonlocal step_count
            step_count += steps
            recent_rewards.append(rew)
            if len(recent_rewards) != REWARD_HISTORY:
                return
            mean_reward = sum(recent_rewards) / len(recent_rewards)
            print('%d steps: mean=%f' % (step_count, mean_reward))
            recent_rewards.clear()

        dqn.train(num_steps=int(1e7),
                  player=player,
                  replay_buffer=UniformReplayBuffer(args.buffer_size),
                  optimize_op=optimize,
                  target_interval=args.target_interval,
                  batch_size=args.batch_size,
                  min_buffer_size=args.min_buffer_size,
                  handle_ep=_handle_ep)

    env.close()
コード例 #14
0
def main():
    """
    Train Rainbow DQN across the environments returned by ``create_envs``.

    The environments run as one batched env with 4-frame stacking. After
    every completed episode, the episode count, cumulative steps and mean
    reward of the last 100 episodes are printed.
    """
    env_fns, env_names = create_envs()
    env = BatchedFrameStack(batched_gym_env(env_fns),
                            num_images=4,
                            concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)  # Use ADAM
        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew, env_rewards):
            """Record one finished episode and print running statistics."""
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            # The former `if total_steps % 1 == 0:` guard was always true,
            # so every episode is reported.
            recent = reward_hist[-100:]
            print('%d episodes, %d steps: mean of last 100 episodes=%f' %
                  (len(reward_hist), total_steps,
                   sum(recent) / len(recent)))

        dqn.train(
            num_steps=2000000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
            handle_ep=_handle_ep,
            num_envs=len(env_fns),
            save_interval=10,
        )
コード例 #15
0
ファイル: rainbow_run_random.py プロジェクト: PeerM/starman
def main():
    """
    Play a freshly initialized (untrained) Rainbow agent.

    The network's variables are initialized but never trained. The agent
    then calls ``player.play()`` 100000 times on a rendering environment,
    discarding the returned trajectories. Ctrl-C stops early, and the
    environment is closed in every case.
    """
    # "results/rainbow/2/videos/6"
    env = make_env(stack=False,
                   scale_rew=False,
                   render=20,
                   monitor=None,
                   timelimit=False)
    # env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    # TODO we might not want to allow backtracking, it kinda hurts in mario
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    config.gpu_options.per_process_gpu_memory_fraction = 0.6

    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        # TODO rebuild the online_net from the saved model
        # type <anyrl.models.dqn_dist.NatureDistQNetwork object at ???>
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

        with tf.device("/cpu"):
            # The network must be initialized before its first forward pass.
            # With this call commented out, running the model raises
            # FailedPreconditionError (uninitialized variables).
            sess.run(tf.global_variables_initializer())
            try:
                for _ in tqdm(range(100000)):
                    player.play()  # Trajectories are intentionally discarded.
            except KeyboardInterrupt:
                pass  # Stop early; the env is closed below.
            finally:
                env.close()
コード例 #16
0
def main():
    """
    Continue training Rainbow DQN from a saved checkpoint and replay buffer.

    The steps are:
    1. Restore the TF checkpoint ``/root/compo/model.ckpt``.
    2. Load a gzip-pickled replay buffer and override its alpha, beta and
       capacity.
    3. Call ``restore_ppo2_weights(sess)``.
    4. Train with ``train_interval=4``.
    """
    env = AllowBacktracking(make_env(stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(sess, "/root/compo/model.ckpt")
        #print('model restored')
        # SECURITY: pickle.load can execute arbitrary code; only load replay
        # buffers from a trusted source. The context manager closes the gzip
        # handle, which was previously leaked.
        with gzip.open('/root/compo/replay_buffer.p.gz', 'rb') as buffer_file:
            replay_buffer = pickle.load(buffer_file)
        replay_buffer.alpha = 0.2
        replay_buffer.beta = 0.4
        # NOTE(review): lowering `capacity` on an already-built buffer may
        # not resize its internal storage or evict excess samples — confirm.
        replay_buffer.capacity = 100000

        # NOTE(review): runs after saver.restore, so it presumably replaces
        # some restored weights with PPO2 ones — confirm this is intended.
        restore_ppo2_weights(sess)

        dqn.train(
            num_steps=2000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=replay_buffer,  # was PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1)
            optimize_op=optimize,
            train_interval=4,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000)
コード例 #17
0
def main():
    """
    Train Rainbow DQN for a fixed number of steps and save the model.

    Builds the env and a Rainbow agent, then runs ``dqn.train`` for
    ``train_steps`` timesteps inside a single-iteration loop that prints the
    wall-clock time. Afterwards all variables are saved with
    ``tf.train.Saver`` to ``saved_models/rainbow5.ckpt``.

    NOTE(review): ``train_steps`` (5000) is smaller than
    ``min_buffer_size`` (10000). dqn.train performs no training until the
    replay buffer holds min_buffer_size samples, so this run likely saves
    untrained weights — confirm.
    """

    print('creating env')

    env = AllowBacktracking(make_env(stack=False, scale_rew=False))

    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)

    config = tf.ConfigProto()

    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    print('starting tf session')

    with tf.Session(config=config) as sess:

        print('creating agent')

        online_net, target_net = rainbow_models(sess,
                                                env.action_space.n,
                                                gym_space_vectorizer(
                                                    env.observation_space),
                                                min_val=-200,
                                                max_val=200)

        dqn = DQN(online_net, target_net)

        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

        optimize = dqn.optimize(learning_rate=1e-4)

        # Created after dqn.optimize so it also sees the optimizer's variables.
        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        train_steps = 5000

        print('training steps:', train_steps)

        # Single pass; `j` is only printed.
        for j in range(1):

            print(j)

            start = time.time()

            dqn.train(
                num_steps=
                train_steps,  # Fixed budget; see the NOTE in the docstring.
                player=player,
                replay_buffer=PrioritizedReplayBuffer(500000,
                                                      0.5,
                                                      0.4,
                                                      epsilon=0.1),
                optimize_op=optimize,
                train_interval=1,
                target_interval=8192,
                batch_size=32,
                min_buffer_size=10000)

            end = time.time()

            # Wall-clock seconds spent in dqn.train.
            print(end - start)

        print('done training')

        print('save nn')

        save_path = saver.save(sess, "saved_models/rainbow5.ckpt")
        print("Model saved in path: %s" % save_path)

        # Values of every trainable variable; currently only consumed by
        # the commented-out debugging code below.
        tvars = tf.trainable_variables()
        tvars_vals = sess.run(tvars)

        #for var, val in zip(tvars, tvars_vals):
        #    print(var.name, val[0])

        #print(tvars_vals[0][-5:])

        #print('stepping')

        #obs = env.reset()

        #online_net.step(obs, obs)
        # The string literal opened on the next line continues past this snippet.
        '''
コード例 #18
0
def main():
    """
    Entry-point for the program.

    Trains Rainbow DQN with epsilon-greedy exploration on ``args.game``,
    using ``args.workers`` environments in one batch. Every
    ``REWARD_HISTORY`` completed episodes, the mean reward is printed and
    the history is reset.
    """
    args = _parse_args()

    # batched_gym_env combines args.workers copies of the environment into
    # one batched env.
    # make_single_env = GrayscaleEnv > DownsampleEnv
    # GrayscaleEnv = turns RGB into grayscale
    # DownsampleEnv = down samples observation by N times where N is the specified variable (e.g. 2x smaller)
    env = batched_gym_env([partial(make_single_env, args.game)] * args.workers)
    # Debug only: show a single env's observation space, then release that
    # env. It is not used for anything else and was previously never closed.
    env_test = make_single_env(args.game)
    print('OBSSSS', env_test.observation_space)
    env_test.close()
    #env = CustomWrapper(args.game)
    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(env, num_images=4, concat=False)

    try:
        with tf.Session() as sess:
            # (An unused nested `make_net` helper that duplicated the
            # rainbow_models call below was removed.)
            dqn = DQN(*rainbow_models(sess,
                                      env.action_space.n,
                                      gym_space_vectorizer(env.observation_space),
                                      min_val=-200,
                                      max_val=200))
            player = BatchedPlayer(env,
                                   EpsGreedyQNetwork(dqn.online_net, args.epsilon))
            optimize = dqn.optimize(learning_rate=args.lr)

            sess.run(tf.global_variables_initializer())

            reward_hist = []
            total_steps = 0

            def _handle_ep(steps, rew):
                nonlocal total_steps
                total_steps += steps
                reward_hist.append(rew)
                if len(reward_hist) == REWARD_HISTORY:
                    print('%d steps: mean=%f' %
                          (total_steps, sum(reward_hist) / len(reward_hist)))
                    reward_hist.clear()

            dqn.train(num_steps=int(1e7),
                      player=player,
                      replay_buffer=UniformReplayBuffer(args.buffer_size),
                      optimize_op=optimize,
                      target_interval=args.target_interval,
                      batch_size=args.batch_size,
                      min_buffer_size=args.min_buffer_size,
                      handle_ep=_handle_ep)
    finally:
        # Close the batched env even if training raises.
        env.close()
コード例 #19
0
    # Train briefly on a few randomly chosen (game, stage) pairs.
    # NOTE(review): this fragment relies on `stages`, `games`, `dqn` and
    # `optimize` being defined before this point (not shown).
    train_steps = 1000  #200000

    for i in range(3):

        stage = random.choice(stages)
        game = random.choice(games)

        print('creating env')
        # NOTE(review): the previous iteration's env is never closed before
        # this one is created; if these are gym-retro envs (one emulator per
        # process), iterations after the first may fail — confirm.
        env = AllowBacktracking(
            make_env_multi(game, stage, stack=False, scale_rew=False))
        env = BatchedFrameStack(BatchedGymEnv([[env]]),
                                num_images=4,
                                concat=False)

        # A fresh player per env; the same dqn (and online network) is
        # reused across iterations.
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

        print(i, game, stage)
        print('training steps:', train_steps)

        start = time.time()

        # NOTE(review): a new replay buffer is created on every iteration, so
        # experience is not carried across environments. Whether 1000 steps
        # can reach min_buffer_size (set past this fragment) is unclear.
        dqn.train(
            num_steps=
            train_steps,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
コード例 #20
0
ファイル: agent.py プロジェクト: ichaelm/ShrubPig
def main():
    """Run DQN until the environment throws an exception.

    Builds one wrapped environment per entry returned by ``make_envs``,
    shares a single Rainbow online/target model pair and a single
    prioritized replay buffer across all of them, restores the conv-layer
    weights listed below from ``/root/compo/model``, then trains in an
    endless loop of 16384-step chunks, saving those same conv-layer
    weights to ``/root/compo/out/model`` after every chunk.
    """
    envs = make_envs(stack=False, scale_rew=False)
    # Wrap every env in place: AllowBacktracking first, then a batched
    # 4-frame stack (concat=False) around a single-env BatchedGymEnv.
    for i in range(len(envs)):
        envs[i] = AllowBacktracking(envs[i])
        envs[i] = BatchedFrameStack(BatchedGymEnv([[envs[i]]]),
                                    num_images=4,
                                    concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        # The model is sized from envs[0] only.
        # NOTE(review): assumes every env shares envs[0]'s action and
        # observation spaces — confirm.
        online_model, target_model = rainbow_models(
            sess,
            envs[0].action_space.n,
            gym_space_vectorizer(envs[0].observation_space),
            min_val=-200,
            max_val=200)
        replay_buffer = PrioritizedReplayBuffer(400000, 0.5, 0.4, epsilon=0.1)
        dqn = DQN(online_model, target_model)
        # One 3-step player per env, all driven by the same online network.
        players = []
        for env in envs:
            player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
            players.append(player)
        optimize = dqn.optimize(learning_rate=1e-4)
        # AUTO_REUSE makes tf.get_variable return the already-created
        # variables by name instead of creating new ones.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            # This saver covers ONLY the conv-layer kernels/biases of the
            # online and target networks; it is used both for the restore
            # below and for every save in the training loop.
            saver = tf.train.Saver([
                tf.get_variable(name) for name in [
                    'online/layer_1/conv2d/kernel',
                    'online/layer_1/conv2d/bias',
                    'online/layer_2/conv2d/kernel',
                    'online/layer_2/conv2d/bias',
                    'online/layer_3/conv2d/kernel',
                    'online/layer_3/conv2d/bias',
                    'target/layer_1/conv2d/kernel',
                    'target/layer_1/conv2d/bias',
                    'target/layer_2/conv2d/kernel',
                    'target/layer_2/conv2d/bias',
                    'target/layer_3/conv2d/kernel',
                    'target/layer_3/conv2d/bias',
                ]
            ])
            # Disabled alternative (kept as an unused string literal):
            # explicitly initialize only the noisy-layer and Adam variables
            # instead of running the global initializer.
            """
          sess.run(tf.variables_initializer([tf.get_variable(name) for name in [
            'online/noisy_layer/weight_mu',
            'online/noisy_layer/bias_mu',
            'online/noisy_layer/weight_sigma',
            'online/noisy_layer/bias_sigma',
            'online/noisy_layer_1/weight_mu',
            'online/noisy_layer_1/bias_mu',
            'online/noisy_layer_1/weight_sigma',
            'online/noisy_layer_1/bias_sigma',
            'online/noisy_layer_2/weight_mu',
            'online/noisy_layer_2/bias_mu',
            'online/noisy_layer_2/weight_sigma',
            'online/noisy_layer_2/bias_sigma',
            'target/noisy_layer/weight_mu',
            'target/noisy_layer/bias_mu',
            'target/noisy_layer/weight_sigma',
            'target/noisy_layer/bias_sigma',
            'target/noisy_layer_1/weight_mu',
            'target/noisy_layer_1/bias_mu',
            'target/noisy_layer_1/weight_sigma',
            'target/noisy_layer_1/bias_sigma',
            'target/noisy_layer_2/weight_mu',
            'target/noisy_layer_2/bias_mu',
            'target/noisy_layer_2/weight_sigma',
            'target/noisy_layer_2/bias_sigma',
              'beta1_power',
              'beta2_power',
              'online/layer_1/conv2d/kernel/Adam',
              'online/layer_1/conv2d/kernel/Adam_1',
              'online/layer_1/conv2d/bias/Adam',
              'online/layer_1/conv2d/bias/Adam_1',
              'online/layer_2/conv2d/kernel/Adam',
              'online/layer_2/conv2d/kernel/Adam_1',
              'online/layer_2/conv2d/bias/Adam',
              'online/layer_2/conv2d/bias/Adam_1',
              'online/layer_3/conv2d/kernel/Adam',
              'online/layer_3/conv2d/kernel/Adam_1',
              'online/layer_3/conv2d/bias/Adam',
              'online/layer_3/conv2d/bias/Adam_1',
              'online/noisy_layer/weight_mu/Adam',
              'online/noisy_layer/weight_mu/Adam_1',
              'online/noisy_layer/bias_mu/Adam',
              'online/noisy_layer/bias_mu/Adam_1',
              'online/noisy_layer/weight_sigma/Adam',
              'online/noisy_layer/weight_sigma/Adam_1',
              'online/noisy_layer/bias_sigma/Adam',
              'online/noisy_layer/bias_sigma/Adam_1',
              'online/noisy_layer_1/weight_mu/Adam',
              'online/noisy_layer_1/weight_mu/Adam_1',
              'online/noisy_layer_1/bias_mu/Adam',
              'online/noisy_layer_1/bias_mu/Adam_1',
              'online/noisy_layer_1/weight_sigma/Adam',
              'online/noisy_layer_1/weight_sigma/Adam_1',
              'online/noisy_layer_1/bias_sigma/Adam',
              'online/noisy_layer_1/bias_sigma/Adam_1',
              'online/noisy_layer_2/weight_mu/Adam',
              'online/noisy_layer_2/weight_mu/Adam_1',
              'online/noisy_layer_2/bias_mu/Adam',
              'online/noisy_layer_2/bias_mu/Adam_1',
              'online/noisy_layer_2/weight_sigma/Adam',
              'online/noisy_layer_2/weight_sigma/Adam_1',
              'online/noisy_layer_2/bias_sigma/Adam',
              'online/noisy_layer_2/bias_sigma/Adam_1',
          ]]))
          """
            #sess.run( tf.initialize_variables( list( tf.get_variable(name) for name in sess.run( tf.report_uninitialized_variables( tf.all_variables( ) ) ) ) ) )
            # Initialize everything, then overwrite the conv layers covered
            # by `saver` with the pre-trained checkpoint values.
            sess.run(tf.global_variables_initializer())
            # either
            saver.restore(sess, '/root/compo/model')
            # end either
        # Debug output: list every global variable name in the graph.
        for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            print(i.name)
        # Train forever in 16384-step chunks; the loop only ends via an
        # exception. Each save writes only the conv-layer variables that
        # `saver` covers.
        # NOTE(review): `players=` (plural) presumably relies on a modified
        # DQN.train that accepts a list of players — confirm against the
        # DQN implementation in use.
        while True:
            dqn.train(num_steps=16384,
                      players=players,
                      replay_buffer=replay_buffer,
                      optimize_op=optimize,
                      train_interval=1,
                      target_interval=8192,
                      batch_size=32,
                      min_buffer_size=20000)
            saver.save(sess, '/root/compo/out/model')
コード例 #21
0
def train(batched_env,
          env_count=1,
          batch_size_multiplier=32,
          num_steps=2000000,
          pretrained_model='artifacts/model/model.cpkt',
          output_dir='artifacts/model',
          use_schedules=True):
    """
    Trains on a batched_env using anyrl-py's dqn and rainbow model.

    batched_env: The batched environment to train on. It is wrapped with
        CollisionMapWrapper and BatchedResizeImageWrapper before use.
    env_count: The number of envs in batched_env. Used as the train
        interval and as the multiplier for the batch size.
    batch_size_multiplier: batch_size of the dqn train call will be
        env_count * batch_size_multiplier
    num_steps: The number of steps to run training for
    pretrained_model: Checkpoint path to load tf weights from, or None to
        start from randomly initialized weights.
    output_dir: Directory that the ScheduledSaver writes tf weights to
    use_schedules: Enables the tf_schedules for the train call. Schedules
        require internet access, so don't include on retro-contest
        evaluation server
    """
    env = CollisionMapWrapper(batched_env)
    env = BatchedResizeImageWrapper(env)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    batch_size = env_count * batch_size_multiplier

    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))

        scheduled_saver = ScheduledSaver(save_interval=10000,
                                         save_dir=output_dir)
        print('Outputting trained model to', output_dir)

        # Reporting uses BatchedPlayer to get _total_rewards
        batched_player = BatchedPlayer(env, dqn.online_net)
        player = NStepPlayer(batched_player, 3)

        optimize = dqn.optimize(learning_rate=1e-4)

        # Always run the initializer first. Restoring afterwards overwrites
        # the variables the checkpoint covers, while anything it does not
        # cover (e.g. optimizer slot variables created by dqn.optimize after
        # the saver was built) still gets a valid initial value instead of
        # being left uninitialized.
        if pretrained_model is None:
            print('Initializing with random weights')
            sess.run(tf.global_variables_initializer())
        else:
            print('Loading pre-trained model from', pretrained_model)
            sess.run(tf.global_variables_initializer())
            scheduled_saver.saver.restore(sess, pretrained_model)

        print('Beginning Training, steps', num_steps)

        tf_schedules = []
        if use_schedules:
            tf_schedules = [
                scheduled_saver,
                LosswiseSchedule(num_steps, batched_player),
                LoadingBar(num_steps)
            ]

        print(batch_size)

        dqn.train(
            num_steps=num_steps,
            player=player,
            replay_buffer=PrioritizedReplayBuffer(300000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            # One training step per env step across the batch.
            train_interval=env_count,
            target_interval=8192,
            batch_size=batch_size,
            # Never start sampling with fewer transitions than one batch.
            min_buffer_size=max(4500, batch_size),
            tf_schedules=tf_schedules,
            handle_ep=print)
        scheduled_saver.save(sess)
コード例 #22
0
def main():
    """Run DQN until the environment throws an exception.

    Intended to run under MPI: each rank builds the evaluation environment
    at index ``rank`` of ``create_eval_envs()`` and runs ``dqn.train`` with
    ``restore_path='./checkpoints_rainbow/model-10'`` (the model to be
    evaluated). After every episode each rank prints the mean of its last
    (up to) 100 episode rewards, and rank 0 also prints that value averaged
    over all ranks.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    comm = MPI.COMM_WORLD

    # Use MPI for parallel evaluation
    rank = comm.Get_rank()
    size = comm.Get_size()

    env_fns, env_names = create_eval_envs()

    env = AllowBacktracking(env_fns[rank](stack=False, scale_rew=False))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)
        sess.run(tf.global_variables_initializer())

        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew, env_rewards=None):
            """Record a finished episode and report local/global scores.

            ``env_rewards`` is accepted but unused; it defaults to None so
            the callback also works when invoked with only (steps, rew).
            """
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            # Mean reward over the most recent (up to) 100 episodes.
            recent = reward_hist[-100:]
            avg_score = sum(recent) / len(recent)

            # Global score: each rank's local average, averaged over ranks.
            # NOTE(review): Allreduce is a collective, so every rank must
            # call it the same number of times. Ranks finish episodes at
            # different moments and in different numbers, so this pairs
            # unrelated episodes, and a rank with more episodes can block
            # here after its peers have finished training. Consider
            # reducing on a fixed step schedule or once after training.
            global_score = np.zeros(1)
            local_score = np.array(avg_score)
            print("Local Score for " + env_names[rank] + " at episode " +
                  str(len(reward_hist)) + " with timesteps: " +
                  str(total_steps) + ": " + str(local_score))
            comm.Allreduce(local_score, global_score, op=MPI.SUM)
            global_score /= size
            if rank == 0:
                print("Global Average Score at episode: " +
                      str(len(reward_hist)) + ": " + str(global_score))

        dqn.train(
            num_steps=2000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
            handle_ep=_handle_ep,
            save_interval=None,
            restore_path=
            './checkpoints_rainbow/model-10'  # Model to be evaluated
        )
def main():
    """Run DQN until the environment throws an exception.

    Copies the first six arrays of a PPO2 weight dump
    (``./ppo2_weights_266.joblib``) into the three conv layers of both the
    online and target networks. It then trains in chunks of
    ``steps_per_env`` steps, each on a freshly created environment picked
    at random from ``get_training_envs()``. After each chunk it:

    - writes the TensorBoard summary returned by ``train`` (if any);
    - saves the model to ``./model/model.ckpt``;
    - pickles the replay buffer to ``./model/replay_buffer.p``.
    """
    envs = get_training_envs()
    game, state = random.choice(envs)
    # This first environment is only used to size the model's action and
    # observation spaces; it is closed before the training loop starts.
    env = make_training_env(game, state, stack=False, scale_rew=False)
    env = prep_env(env)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*models(sess,
                          env.action_space.n,
                          gym_space_vectorizer(env.observation_space),
                          min_val=-200,
                          max_val=200))
        optimize = dqn.optimize(learning_rate=1e-4)
        loss = dqn.loss
        train_writer = tf.summary.FileWriter('./logs/multiple/train',
                                             sess.graph)
        tf.summary.scalar("loss", loss)
        reward = tf.Variable(0., name='reward', trainable=False)
        tf.summary.scalar('reward', tf.reduce_mean(reward))
        steps = tf.Variable(0, name='steps', trainable=False)
        tf.summary.scalar('steps', tf.reduce_mean(steps))
        summary_op = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        print(tf.trainable_variables())

        # PPO2 parameter arrays. The PPO2 model's variables were recorded as
        # c1/w (8, 8, 4, 32), c1/b (1, 32, 1, 1), c2/w (4, 4, 32, 64),
        # c2/b (1, 64, 1, 1), c3/w (3, 3, 64, 64), c3/b (1, 64, 1, 1),
        # fc1/w (3136, 512), fc1/b (512,), v/w (512, 1), v/b (1,),
        # pi/w (512, 7), pi/b (7,).
        # NOTE(review): assumes the joblib list follows that order — only
        # the first six entries (the conv layers) are used below.
        weights = joblib.load('./ppo2_weights_266.joblib')

        graph = tf.get_default_graph()
        conv_tensor_templates = (
            '{}/layer_1/conv2d/kernel:0',
            '{}/layer_1/conv2d/bias:0',
            '{}/layer_2/conv2d/kernel:0',
            '{}/layer_2/conv2d/bias:0',
            '{}/layer_3/conv2d/kernel:0',
            '{}/layer_3/conv2d/bias:0',
        )
        for model in ('online', 'target'):
            for index, template in enumerate(conv_tensor_templates):
                tensor_name = template.format(model)
                tensor = graph.get_tensor_by_name(tensor_name)
                weight = weights[index]
                if 'bias' in tensor_name:
                    # PPO2 biases are stored with extra singleton dims;
                    # reshape to the destination variable's shape.
                    weight = np.reshape(weight, tensor.get_shape())
                print('about to assign {} value with size {}'.format(
                    tensor_name, weights[index].shape))
                sess.run(tf.assign(tensor, weight))

        saver = tf.train.Saver()
        saver.save(sess, "./model/model.ckpt")
        print('Saved model')
        # To resume from a previous run, the buffer could instead be loaded
        # from the pickle written after every chunk (./model/replay_buffer.p).
        replay_buffer = PrioritizedReplayBuffer(100000, 0.5, 0.4, epsilon=0.1)

        total_steps = 50000000
        steps_per_env = 5000
        env.close()

        for i in range(total_steps // steps_per_env):
            game, state = random.choice(envs)
            env = make_training_env(game, state, stack=False, scale_rew=False)
            env = prep_env(env)
            player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)

            # NOTE(review): these callbacks build new assign ops on every
            # call, so the graph grows with each step/episode, and the ops
            # only take effect if `train` runs the returned tuple — confirm.
            try:
                summary = train(
                    dqn,
                    num_steps=steps_per_env,
                    player=player,
                    replay_buffer=replay_buffer,
                    optimize_op=optimize,
                    train_interval=4,
                    target_interval=8192,
                    batch_size=32,
                    min_buffer_size=20000,
                    summary_op=summary_op,
                    handle_ep=lambda st, rew:
                    (reward.assign(rew), steps.assign(st)),
                    handle_step=lambda st, rew:
                    (reward.assign(reward + rew), steps.assign(steps + st)))
            finally:
                # Release the environment even if training raises.
                env.close()

            if summary:
                train_writer.add_summary(summary, i)
                train_writer.flush()
            else:
                print('No summary')

            saver.save(sess, "./model/model.ckpt")
            with open("./model/replay_buffer.p", "wb") as buffer_file:
                pickle.dump(replay_buffer, buffer_file)
            print('Saved model')
コード例 #24
0
def main():
    """Run DQN until the environment throws an exception.

    Command-line flags:
      --restore  restore weights from the latest checkpoint in ./results
                 before training.
      --record   record bk2 movies (passed through to ``make_env``).

    Per-episode rewards, and the mean reward over every REWARD_HISTORY
    episodes, are written as TensorBoard summaries to a timestamped
    directory under ./results. A checkpoint is saved to ./results after
    every episode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore',
                        '-restore',
                        action='store_true',
                        help='restore from checkpoint file')
    parser.add_argument('--record',
                        '-record',
                        action='store_true',
                        help='record bk2 movies')
    args = parser.parse_args()

    env = AllowBacktracking(
        make_env(stack=False, scale_rew=False, record=args.record))
    env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)

    checkpoint_dir = os.path.join(os.getcwd(), 'results')
    results_dir = os.path.join(checkpoint_dir,
                               time.strftime("%d-%m-%Y_%H-%M-%S"))
    os.makedirs(results_dir, exist_ok=True)
    summary_writer = tf.summary.FileWriter(results_dir)

    # TODO
    # env = wrappers.Monitor(env, results_dir, force=True)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))

        # Built before dqn.optimize, so it covers the model variables but
        # not the optimizer's — the same variable set as the checkpoints
        # this script has always written.
        saver = tf.train.Saver()

        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)

        # Initialize first, then restore: running the initializer after a
        # restore (as before) silently overwrote the restored weights.
        sess.run(tf.global_variables_initializer())
        if args.restore:
            latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
            if latest_checkpoint:
                print("Loading model checkpoint {} ...\n".format(
                    latest_checkpoint))
                saver.restore(sess, latest_checkpoint)
            else:
                print("Checkpoint not found")

        reward_hist = []
        total_steps = 0

        # runs with every completed episode
        def _handle_ep(steps, rew):
            """Log the episode reward, checkpoint, and report window means."""
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)

            summary_reward = tf.Summary()
            summary_reward.value.add(tag='global/reward', simple_value=rew)
            summary_writer.add_summary(summary_reward, global_step=total_steps)

            print('save model')
            saver.save(sess=sess,
                       save_path=checkpoint_dir + '/model',
                       global_step=total_steps)

            if len(reward_hist) == REWARD_HISTORY:
                mean_reward = sum(reward_hist) / len(reward_hist)
                print('%d steps: mean=%f' % (total_steps, mean_reward))
                summary_meanreward = tf.Summary()
                summary_meanreward.value.add(tag='global/mean_reward',
                                             simple_value=mean_reward)
                summary_writer.add_summary(summary_meanreward,
                                           global_step=total_steps)
                reward_hist.clear()

        dqn.train(
            num_steps=7000000,  # Make sure an exception arrives before we stop.
            player=player,
            replay_buffer=PrioritizedReplayBuffer(500000,
                                                  0.5,
                                                  0.4,
                                                  epsilon=0.1),
            optimize_op=optimize,
            train_interval=1,
            target_interval=8192,
            batch_size=32,
            min_buffer_size=20000,
            handle_ep=_handle_ep)