Example #1
def evaluation(session, graph_ops, saver):
    saver.restore(session, CHECKPOINT_NAME)
    print "Restored model weights from ", CHECKPOINT_NAME
    monitor_env = gym.make(GAME)
    monitor_env.monitor.start('/tmp/' + EXPERIMENT_NAME + "/eval")

    # Unpack graph ops
    s, a_t, R_t, minimize, p_network, v_network = graph_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=monitor_env,
                           resized_width=RESIZED_WIDTH,
                           resized_height=RESIZED_HEIGHT,
                           agent_history_length=AGENT_HISTORY_LENGTH)

    for i_episode in xrange(100):
        s_t = env.get_initial_state()
        ep_reward = 0
        terminal = False
        while not terminal:
            monitor_env.render()
            # Forward the policy network, get action probabilities
            probs = p_network.eval(session=session, feed_dict={s: [s_t]})[0]
            action_index = sample_policy_action(ACTIONS, probs)
            s_t1, r_t, terminal, info = env.step(action_index)
            s_t = s_t1
            ep_reward += r_t
        print ep_reward
    monitor_env.monitor.close()
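The helper sample_policy_action used above is not shown in this listing. A minimal sketch of one common implementation, which draws an action index from the policy's probability vector (the num_actions parameter is kept only to mirror the call signature, and the tiny subtraction guards against the probabilities summing to slightly more than 1 after float rounding):

import numpy as np

def sample_policy_action(num_actions, probs):
    # Subtract a tiny constant so np.random.multinomial never sees a total > 1
    probs = probs - np.finfo(np.float32).epsneg
    histogram = np.random.multinomial(1, probs)
    return int(np.argmax(histogram))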
Example #2
def evaluation(session, graph_ops, saver):
    saver.restore(session, FLAGS.checkpoint_path)
    print "Restored model weights from ", FLAGS.checkpoint_path
    monitor_env = gym.make(FLAGS.game)
    monitor_env.monitor.start(FLAGS.eval_dir + "/" + FLAGS.experiment +
                              "/eval")

    # Unpack graph ops
    s = graph_ops["s"]
    q_values = graph_ops["q_values"]

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=monitor_env,
                           resized_width=FLAGS.resized_width,
                           resized_height=FLAGS.resized_height,
                           agent_history_length=FLAGS.agent_history_length)

    for i_episode in xrange(FLAGS.num_eval_episodes):
        s_t = env.get_initial_state()
        ep_reward = 0
        terminal = False
        while not terminal:
            monitor_env.render()
            readout_t = q_values.eval(session=session, feed_dict={s: [s_t]})
            action_index = np.argmax(readout_t)
            s_t1, r_t, terminal, info = env.step(action_index)
            s_t = s_t1
            ep_reward += r_t
        print ep_reward
    monitor_env.monitor.close()
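All of these examples rely on an AtariEnvironment helper that is not included in the listing. As a rough sketch (details vary between the repositories shown; for instance Example #10 unpacks two values from get_initial_state, and the real helper also handles action-space remapping), it typically grayscales and resizes each frame and keeps a rolling stack of the last agent_history_length frames as the state fed to the network:

from collections import deque
import numpy as np

class AtariEnvironmentSketch(object):
    """Illustrative stand-in for the AtariEnvironment wrapper used in these examples."""

    def __init__(self, gym_env, resized_width, resized_height, agent_history_length):
        self.env = gym_env
        self.resized_width = resized_width
        self.resized_height = resized_height
        self.frames = deque(maxlen=agent_history_length)

    def _preprocess(self, frame):
        # Grayscale via channel mean, then a crude nearest-neighbour resize.
        gray = frame.mean(axis=2)
        rows = np.linspace(0, gray.shape[0] - 1, self.resized_height).astype(int)
        cols = np.linspace(0, gray.shape[1] - 1, self.resized_width).astype(int)
        return gray[np.ix_(rows, cols)]

    def get_initial_state(self):
        # Fill the frame stack with copies of the first preprocessed frame.
        frame = self._preprocess(self.env.reset())
        for _ in range(self.frames.maxlen):
            self.frames.append(frame)
        return np.stack(self.frames, axis=0)

    def step(self, action_index):
        obs, reward, terminal, info = self.env.step(action_index)
        self.frames.append(self._preprocess(obs))
        return np.stack(self.frames, axis=0), reward, terminal, info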
Example #3
def evaluation(session, graph_ops, saver):
    saver.restore(session, CHECKPOINT_NAME)
    print "Restored model weights from ", CHECKPOINT_NAME
    monitor_env = gym.make(GAME)
    monitor_env.monitor.start('/tmp/'+EXPERIMENT_NAME+"/eval")

    # Unpack graph ops
    s, a_t, R_t, learning_rate, minimize, p_network, v_network = graph_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=monitor_env, resized_width=RESIZED_WIDTH, resized_height=RESIZED_HEIGHT, agent_history_length=AGENT_HISTORY_LENGTH)

    for i_episode in xrange(100):
        s_t = env.get_initial_state()
        ep_reward = 0
        terminal = False
        while not terminal:
            monitor_env.render()
            # Forward the policy network, get action probabilities
            probs = p_network.eval(session = session, feed_dict = {s : [s_t]})[0]
            action_index = sample_policy_action(ACTIONS, probs)
            s_t1, r_t, terminal, info = env.step(action_index)
            s_t = s_t1
            ep_reward += r_t
        print ep_reward
    monitor_env.monitor.close()
Example #4
def evaluation(session, graph_ops, saver):
    saver.restore(session, FLAGS.checkpoint_path)
    print "Restored model weights from ", FLAGS.checkpoint_path
    monitor_env = gym.make(FLAGS.game)
    monitor_env.monitor.start(FLAGS.eval_dir+"/"+FLAGS.experiment+"/eval")

    # Unpack graph ops
    s = graph_ops["s"]
    q_values = graph_ops["q_values"]

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=monitor_env, resized_width=FLAGS.resized_width, resized_height=FLAGS.resized_height, agent_history_length=FLAGS.agent_history_length)

    for i_episode in xrange(FLAGS.num_eval_episodes):
        s_t = env.get_initial_state()
        ep_reward = 0
        terminal = False
        while not terminal:
            monitor_env.render()
            readout_t = q_values.eval(session = session, feed_dict = {s : [s_t]})
            action_index = np.argmax(readout_t)
            s_t1, r_t, terminal, info = env.step(action_index)
            s_t = s_t1
            ep_reward += r_t
        print ep_reward
    monitor_env.monitor.close()
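Examples #1 to #4 use the env.monitor.start / env.monitor.close API from older gym releases; Example #5 below targets the later gym.wrappers.Monitor API instead. A minimal sketch of the newer equivalent of the recording setup above (the environment id and output directory are placeholders):

import gym

env = gym.make("Breakout-v0")  # placeholder game id
env = gym.wrappers.Monitor(env, "/tmp/experiment/eval", force=True)
# ... run evaluation episodes against the wrapped env ...
env.close()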
Example #5
def evaluation(session, graph_ops, saver):
    saver.restore(session, FLAGS.checkpoint_path)
    print("Restored model weights from ", FLAGS.checkpoint_path)
    monitor_env = gym.make(FLAGS.game)
    monitor_env = gym.wrappers.Monitor(monitor_env,
                                       FLAGS.eval_dir + "/" + FLAGS.experiment + "/eval")

    # Unpack graph ops
    s = graph_ops["s"]
    q_values = graph_ops["q_values"]

    # Wrap env with AtariEnvironment helper class
    # env_type is assumed to be defined elsewhere in the module (e.g. a FLAGS value)
    if env_type in {'atari'}:
        env = AtariEnvironment(gym_env=monitor_env,
                               resized_width=FLAGS.resized_width,
                               resized_height=FLAGS.resized_height,
                               agent_history_length=FLAGS.agent_history_length)
    else:
        env = CustomEnvironment(
            gym_env=monitor_env,
            input_size=FLAGS.input_size,
            agent_history_length=FLAGS.agent_history_length,
            extra_args={
                'init_with_args': FLAGS.init_with_args,
                'setting_file_path': FLAGS.setting_file_path
            })

    for i_episode in range(FLAGS.num_eval_episodes):
        s_t = env.get_initial_state()
        ep_reward = 0
        terminal = False
        while not terminal:
            monitor_env.render()
            readout_t = q_values.eval(session=session, feed_dict={s: [s_t]})
            action_index = np.argmax(readout_t)
            print("action", action_index)
            s_t1, r_t, terminal, info = env.step(action_index)
            s_t = s_t1
            ep_reward += r_t
        print(ep_reward)
    monitor_env.close()
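As a rough sketch of how an evaluation routine like the ones above is typically driven (build_graph is a hypothetical stand-in for whatever builder the repository uses to produce graph_ops; TF 1.x session semantics assumed):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    graph_ops = build_graph()  # hypothetical: returns a dict with "s", "q_values", ...
    saver = tf.train.Saver()
    with tf.Session() as session:
        evaluation(session, graph_ops, saver)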
Example #6
def actor_learner_thread(num, env, session, graph_ops, summary_ops, saver):
    # We use the global shared counter T and the TMAX constant
    global TMAX, T

    # Unpack graph ops
    s, a, R, minimize, p_network, v_network = graph_ops

    # Unpack TensorBoard summary ops
    r_summary_placeholder, update_ep_reward, val_summary_placeholder, update_ep_val, summary_op = summary_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=env,
                           resized_width=RESIZED_WIDTH,
                           resized_height=RESIZED_HEIGHT,
                           agent_history_length=AGENT_HISTORY_LENGTH)

    time.sleep(5 * num)

    # Set up per-episode counters
    ep_reward = 0
    ep_avg_v = 0
    v_steps = 0
    ep_t = 0

    probs_summary_t = 0

    s_t = env.get_initial_state()
    terminal = False

    while T < TMAX:
        s_batch = []
        past_rewards = []
        a_batch = []

        t = 0
        t_start = t

        while not (terminal or ((t - t_start) == t_max)):
            # Perform action a_t according to policy pi(a_t | s_t)
            probs = session.run(p_network, feed_dict={s: [s_t]})[0]
            action_index = sample_policy_action(ACTIONS, probs)
            a_t = np.zeros([ACTIONS])
            a_t[action_index] = 1

            if probs_summary_t % 100 == 0:
                print "P, ", np.max(probs), "V ", session.run(
                    v_network, feed_dict={s: [s_t]})[0][0]

            s_batch.append(s_t)
            a_batch.append(a_t)

            s_t1, r_t, terminal, info = env.step(action_index)
            ep_reward += r_t

            r_t = np.clip(r_t, -1, 1)
            past_rewards.append(r_t)

            t += 1
            T += 1
            ep_t += 1
            probs_summary_t += 1

            s_t = s_t1

        if terminal:
            R_t = 0
        else:
            R_t = session.run(v_network,
                              feed_dict={s: [s_t]
                                         })[0][0]  # Bootstrap from last state

        R_batch = np.zeros(t)
        for i in reversed(range(t_start, t)):
            R_t = past_rewards[i] + GAMMA * R_t
            R_batch[i] = R_t

        session.run(minimize, feed_dict={R: R_batch, a: a_batch, s: s_batch})

        # Save progress every CHECKPOINT_INTERVAL iterations
        if T % CHECKPOINT_INTERVAL == 0:
            saver.save(session, CHECKPOINT_SAVE_PATH, global_step=T)

        if terminal:
            # Episode ended, collect stats and reset game
            session.run(update_ep_reward,
                        feed_dict={r_summary_placeholder: ep_reward})
            print "THREAD:", num, "/ TIME", T, "/ REWARD", ep_reward
            s_t = env.get_initial_state()
            terminal = False
            # Reset per-episode counters
            ep_reward = 0
            ep_t = 0
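The loop above accumulates the bootstrapped n-step return R_t = r_t + GAMMA * R_{t+1}, walking backwards from the value of the last state. The same computation, isolated as a small standalone helper for clarity (a sketch, not part of the original code):

import numpy as np

def n_step_returns(rewards, bootstrap_value, gamma):
    # Walk backwards from the bootstrap value, discounting as we go.
    returns = np.zeros(len(rewards))
    running = bootstrap_value
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns

# n_step_returns([0.0, 1.0, 0.0], bootstrap_value=0.5, gamma=0.99)
# -> array([1.4751495, 1.49005, 0.495])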
Example #7
def actor_learner_thread(thread_id, env, session, graph_ops, num_actions,
                         summary_ops, saver):
    """
    Actor-learner thread implementing asynchronous one-step Q-learning, as specified
    in algorithm 1 here: http://arxiv.org/pdf/1602.01783v1.pdf.
    """
    global TMAX, T

    # Unpack graph ops
    s = graph_ops["s"]
    q_values = graph_ops["q_values"]
    st = graph_ops["st"]
    target_q_values = graph_ops["target_q_values"]
    reset_target_network_params = graph_ops["reset_target_network_params"]
    a = graph_ops["a"]
    y = graph_ops["y"]
    grad_update = graph_ops["grad_update"]

    summary_placeholders, update_ops, summary_op = summary_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=env,
                           resized_width=FLAGS.resized_width,
                           resized_height=FLAGS.resized_height,
                           agent_history_length=FLAGS.agent_history_length)

    # Initialize network gradients
    s_batch = []
    a_batch = []
    y_batch = []

    final_epsilon = sample_final_epsilon()
    initial_epsilon = 1.0
    epsilon = 1.0

    print "Starting thread ", thread_id, "with final epsilon ", final_epsilon

    time.sleep(3 * thread_id)
    t = 0
    while T < TMAX:
        # Get initial game observation
        s_t = env.get_initial_state()
        terminal = False

        # Set up per-episode counters
        ep_reward = 0
        episode_ave_max_q = 0
        ep_t = 0

        while True:
            # Forward the deep q network, get Q(s,a) values
            readout_t = q_values.eval(session=session, feed_dict={s: [s_t]})

            # Choose next action based on e-greedy policy
            a_t = np.zeros([num_actions])
            action_index = 0
            if random.random() <= epsilon:
                action_index = random.randrange(num_actions)
            else:
                action_index = np.argmax(readout_t)
            a_t[action_index] = 1

            # Scale down epsilon
            if epsilon > final_epsilon:
                epsilon -= (initial_epsilon -
                            final_epsilon) / FLAGS.anneal_epsilon_timesteps

            # Gym executes the action in the game environment on behalf of the actor-learner
            s_t1, r_t, terminal, info = env.step(action_index)

            # Accumulate gradients
            readout_j1 = target_q_values.eval(session=session,
                                              feed_dict={st: [s_t1]})
            clipped_r_t = np.clip(r_t, -1, 1)
            if terminal:
                y_batch.append(clipped_r_t)
            else:
                y_batch.append(clipped_r_t + FLAGS.gamma * np.max(readout_j1))

            a_batch.append(a_t)
            s_batch.append(s_t)

            # Update the state and counters
            s_t = s_t1
            T += 1
            t += 1

            ep_t += 1
            ep_reward += r_t
            episode_ave_max_q += np.max(readout_t)

            # Optionally update target network
            if T % FLAGS.target_network_update_frequency == 0:
                session.run(reset_target_network_params)

            # Optionally update online network
            if t % FLAGS.network_update_frequency == 0 or terminal:
                if s_batch:
                    session.run(grad_update,
                                feed_dict={
                                    y: y_batch,
                                    a: a_batch,
                                    s: s_batch
                                })
                # Clear gradients
                s_batch = []
                a_batch = []
                y_batch = []

            # Save model progress
            if t % FLAGS.checkpoint_interval == 0:
                saver.save(session,
                           FLAGS.checkpoint_dir + "/" + FLAGS.experiment +
                           ".ckpt",
                           global_step=t)

            # Print end of episode stats
            if terminal:
                stats = [ep_reward, episode_ave_max_q / float(ep_t), epsilon]
                for i in range(len(stats)):
                    session.run(
                        update_ops[i],
                        feed_dict={summary_placeholders[i]: float(stats[i])})
                print "THREAD:", thread_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep_reward, "/ Q_MAX %.4f" % (
                    episode_ave_max_q /
                    float(ep_t)), "/ EPSILON PROGRESS", t / float(
                        FLAGS.anneal_epsilon_timesteps)
                break
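sample_final_epsilon is not defined in these snippets. In the asynchronous-methods paper that the docstring cites, each thread anneals epsilon toward a final value sampled from {0.1, 0.01, 0.5} with probabilities 0.4, 0.3 and 0.3; a sketch along those lines:

import numpy as np

def sample_final_epsilon():
    # Final epsilon values and sampling probabilities as described in the cited paper.
    final_epsilons = np.array([0.1, 0.01, 0.5])
    probabilities = np.array([0.4, 0.3, 0.3])
    return float(np.random.choice(final_epsilons, p=probabilities))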
Example #8
def actor_learner_thread(num, env, session, graph_ops, summary_ops, saver):
    # We use the global shared counter T and the TMAX constant
    global TMAX, T

    # Unpack graph ops
    s, a, R, minimize, p_network, v_network = graph_ops

    # Unpack TensorBoard summary ops
    r_summary_placeholder, update_ep_reward, val_summary_placeholder, update_ep_val, summary_op = summary_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=env, resized_width=RESIZED_WIDTH, resized_height=RESIZED_HEIGHT, agent_history_length=AGENT_HISTORY_LENGTH)

    time.sleep(5*num)

    # Set up per-episode counters
    ep_reward = 0
    ep_avg_v = 0
    v_steps = 0
    ep_t = 0

    probs_summary_t = 0

    s_t = env.get_initial_state()
    terminal = False

    while T < TMAX:
        s_batch = []
        past_rewards = []
        a_batch = []

        t = 0
        t_start = t

        while not (terminal or ((t - t_start) == t_max)):
            # Perform action a_t according to policy pi(a_t | s_t)
            probs = session.run(p_network, feed_dict={s: [s_t]})[0]
            action_index = sample_policy_action(ACTIONS, probs)
            a_t = np.zeros([ACTIONS])
            a_t[action_index] = 1

            if probs_summary_t % 100 == 0:
                print "P, ", np.max(probs), "V ", session.run(v_network, feed_dict={s: [s_t]})[0][0]

            s_batch.append(s_t)
            a_batch.append(a_t)

            s_t1, r_t, terminal, info = env.step(action_index)
            ep_reward += r_t

            r_t = np.clip(r_t, -1, 1)
            past_rewards.append(r_t)

            t += 1
            T += 1
            ep_t += 1
            probs_summary_t += 1
            
            s_t = s_t1

        if terminal:
            R_t = 0
        else:
            R_t = session.run(v_network, feed_dict={s: [s_t]})[0][0] # Bootstrap from last state

        R_batch = np.zeros(t)
        for i in reversed(range(t_start, t)):
            R_t = past_rewards[i] + GAMMA * R_t
            R_batch[i] = R_t

        session.run(minimize, feed_dict={R : R_batch,
                                         a : a_batch,
                                         s : s_batch})
        
        # Save progress every CHECKPOINT_INTERVAL iterations
        if T % CHECKPOINT_INTERVAL == 0:
            saver.save(session, CHECKPOINT_SAVE_PATH, global_step = T)

        if terminal:
            # Episode ended, collect stats and reset game
            session.run(update_ep_reward, feed_dict={r_summary_placeholder: ep_reward})
            print "THREAD:", num, "/ TIME", T, "/ REWARD", ep_reward
            s_t = env.get_initial_state()
            terminal = False
            # Reset per-episode counters
            ep_reward = 0
            ep_t = 0
Example #9
def actor_learner_thread(thread_id, env, session, graph_ops, num_actions, summary_ops, saver):
    """
    Actor-learner thread implementing asynchronous one-step Q-learning, as specified
    in algorithm 1 here: http://arxiv.org/pdf/1602.01783v1.pdf.
    """
    global TMAX, T

    # Unpack graph ops
    s = graph_ops["s"]
    q_values = graph_ops["q_values"]
    st = graph_ops["st"]
    target_q_values = graph_ops["target_q_values"]
    reset_target_network_params = graph_ops["reset_target_network_params"]
    a = graph_ops["a"]
    y = graph_ops["y"]
    grad_update = graph_ops["grad_update"]

    summary_placeholders, update_ops, summary_op = summary_ops

    # Wrap env with AtariEnvironment helper class
    env = AtariEnvironment(gym_env=env, resized_width=FLAGS.resized_width, resized_height=FLAGS.resized_height, agent_history_length=FLAGS.agent_history_length)

    # Initialize network gradients
    s_batch = []
    a_batch = []
    y_batch = []

    final_epsilon = sample_final_epsilon()
    initial_epsilon = 1.0
    epsilon = 1.0

    print "Starting thread ", thread_id, "with final epsilon ", final_epsilon

    time.sleep(3*thread_id)
    t = 0
    while T < TMAX:
        # Get initial game observation
        s_t = env.get_initial_state()
        terminal = False

        # Set up per-episode counters
        ep_reward = 0
        episode_ave_max_q = 0
        ep_t = 0

        while True:
            # Forward the deep q network, get Q(s,a) values
            readout_t = q_values.eval(session = session, feed_dict = {s : [s_t]})
            
            # Choose next action based on e-greedy policy
            a_t = np.zeros([num_actions])
            action_index = 0
            if random.random() <= epsilon:
                action_index = random.randrange(num_actions)
            else:
                action_index = np.argmax(readout_t)
            a_t[action_index] = 1

            # Scale down epsilon
            if epsilon > final_epsilon:
                epsilon -= (initial_epsilon - final_epsilon) / FLAGS.anneal_epsilon_timesteps
    
            # Gym executes the action in the game environment on behalf of the actor-learner
            s_t1, r_t, terminal, info = env.step(action_index)

            # Accumulate gradients
            readout_j1 = target_q_values.eval(session = session, feed_dict = {st : [s_t1]})
            clipped_r_t = np.clip(r_t, -1, 1)
            if terminal:
                y_batch.append(clipped_r_t)
            else:
                y_batch.append(clipped_r_t + FLAGS.gamma * np.max(readout_j1))
    
            a_batch.append(a_t)
            s_batch.append(s_t)
    
            # Update the state and counters
            s_t = s_t1
            T += 1
            t += 1

            ep_t += 1
            ep_reward += r_t
            episode_ave_max_q += np.max(readout_t)

            # Optionally update target network
            if T % FLAGS.target_network_update_frequency == 0:
                session.run(reset_target_network_params)
    
            # Optionally update online network
            if t % FLAGS.network_update_frequency == 0 or terminal:
                if s_batch:
                    session.run(grad_update, feed_dict = {y : y_batch,
                                                          a : a_batch,
                                                          s : s_batch})
                # Clear gradients
                s_batch = []
                a_batch = []
                y_batch = []
    
            # Save model progress
            if t % FLAGS.checkpoint_interval == 0:
                saver.save(session, FLAGS.checkpoint_dir+"/"+FLAGS.experiment+".ckpt", global_step = t)
    
            # Print end of episode stats
            if terminal:
                stats = [ep_reward, episode_ave_max_q/float(ep_t), epsilon]
                for i in range(len(stats)):
                    session.run(update_ops[i], feed_dict={summary_placeholders[i]:float(stats[i])})
                print "THREAD:", thread_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep_reward, "/ Q_MAX %.4f" % (episode_ave_max_q/float(ep_t)), "/ EPSILON PROGRESS", t/float(FLAGS.anneal_epsilon_timesteps)
                break
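The reset_target_network_params op used in Examples #7 and #9 is built elsewhere in those repositories. One common TF 1.x construction copies the online network's variables into the target network's variables (a sketch; the variable scope names are assumptions):

import tensorflow as tf

# Assumes the online and target networks were built under these variable scopes.
network_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="online")
target_network_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="target")

reset_target_network_params = [
    target_network_params[i].assign(network_params[i])
    for i in range(len(target_network_params))
]
# session.run(reset_target_network_params) copies online weights into the target net.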
Example #10
class Agent(Process):
    def __init__(self, id, prediction_q, training_q, episode_log_q):
        super(Agent, self).__init__(name="Agent_{}".format(id))
        self.id = id
        self.prediction_q = prediction_q
        self.training_q = training_q
        self.episode_log_q = episode_log_q

        gym_env = gym.make(FLAGS.game)
        gym_env.seed(FLAGS.seed)

        self.env = AtariEnvironment(gym_env=gym_env, resized_width=FLAGS.resized_width,
                                    resized_height=FLAGS.resized_height,
                                    agent_history_length=FLAGS.agent_history_length)

        self.nb_actions = len(self.env.gym_actions)
        self.wait_q = Queue(maxsize=1)
        self.stop = Value('i', 0)

    def run(self):
        time.sleep(np.random.rand())

        while not self.stop.value:
            if FLAGS.verbose:
                print("Agent_{} started a new episode".format(self.id))
            # total_reward = 0
            # total_length = 0
            for episode_buffer, episode_reward, episode_length in self.run_episode_generator():
                if FLAGS.verbose:
                    print("Agent_{} puts a new episode in the training queue".format(self.id))
                self.training_q.put(episode_buffer)
            print("Agent_{} fished an episode and logs the result in the logs queue".format(self.id))
            self.episode_log_q.put([datetime.now(), episode_reward, episode_length])

    def run_episode_generator(self):
        s, _ = self.env.get_initial_state()

        d = False
        episode_buffer = []
        episode_reward = 0
        episode_step_count = 0

        while not d:
            self.prediction_q.put((self.id, s))
            pi, v = self.wait_q.get()
            a = np.random.choice(pi[0], p=pi[0])
            a = np.argmax(pi == a)

            s1, r, d, info = self.env.step(a)

            r = np.clip(r, -1, 1)

            episode_buffer.append([s, a, pi, r, s1, d, v[0, 0]])
            episode_reward += r
            episode_step_count += 1
            s = s1

            if len(episode_buffer) == FLAGS.max_episode_buffer_size and not d:
                self.prediction_q.put((self.id, s))
                pi, v1 = self.wait_q.get()
                updated_episode_buffer = self.get_training_data(episode_buffer, v1)
                yield updated_episode_buffer, episode_reward, episode_step_count
            if d:
                break

        if len(episode_buffer) != 0:
            updated_episode_buffer = self.get_training_data(episode_buffer, 0)
            yield updated_episode_buffer, episode_reward, episode_step_count

    def discount(self, x):
        return lfilter([1], [1, -FLAGS.gamma], x[::-1], axis=0)[::-1]

    def get_training_data(self, rollout, bootstrap_value):
        rollout = np.array(rollout)
        observations = rollout[:, 0]
        actions = rollout[:, 1]
        pis = rollout[:, 2]
        rewards = rollout[:, 3]
        next_observations = rollout[:, 4]
        values = rollout[:, 5]

        rewards_plus = np.asarray(rewards.tolist() + [bootstrap_value])
        discounted_rewards = self.discount(rewards_plus)[:-1]
        value_plus = np.asarray(values.tolist() + [bootstrap_value])
        policy_target = discounted_rewards - value_plus[:-1]

        # np.array(rollout) has no extend(); return the processed arrays instead.
        # The structure expected by the training-queue consumer is not shown in this
        # snippet, so this return format is an assumption.
        return observations, actions, pis, discounted_rewards, policy_target
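The discount helper in Example #10 relies on scipy.signal.lfilter (the import is not shown in the snippet). It computes discounted cumulative sums; the explicit loop below gives the same result and may be easier to read:

import numpy as np
from scipy.signal import lfilter

def discount(x, gamma):
    # Discounted cumulative sum via an IIR filter, as in Example #10.
    return lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

def discount_loop(x, gamma):
    # Equivalent explicit backward accumulation.
    out = np.zeros(len(x))
    running = 0.0
    for i in reversed(range(len(x))):
        running = x[i] + gamma * running
        out[i] = running
    return out

rewards = [0.0, 0.0, 1.0]
assert np.allclose(discount(rewards, 0.99), discount_loop(rewards, 0.99))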