class Worker(object):
    """
    An A3C worker thread. Runs episodes locally and updates the global shared
    value and policy nets.

    Args:
        name: A unique name for this worker
        env: The Gym environment used by this worker
        policy_net: Instance of the globally shared policy net
        value_net: Instance of the globally shared value net
        global_counter: Iterator that holds the global step
        discount_factor: Reward discount factor
        summary_writer: A tf.train.SummaryWriter for Tensorboard summaries
        max_global_steps: If set, stop coordinator when global_counter > max_global_steps
    """

    def __init__(self, name, env, policy_net, value_net, global_counter,
                 discount_factor=0.99, summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.summary_writer = summary_writer
        self.env = env

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)
            self.value_net = ValueEstimator(reuse=True)

        # Op to copy params from the global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op = make_train_op(self.value_net, self.global_value_net)
        self.pnet_train_op = make_train_op(self.policy_net, self.global_policy_net)

        self.state = None

    def run(self, sess, coord, t_max):
        with sess.as_default(), sess.graph.as_default():
            # Initial state
            self.state = self.env.reset()
            try:
                while not coord.should_stop():
                    # Copy parameters from the global networks
                    sess.run(self.copy_params_op)

                    # Collect some experience
                    transitions, local_t, global_t = self.run_n_steps(t_max, sess)

                    if self.max_global_steps is not None and global_t >= self.max_global_steps:
                        tf.logging.info(
                            "Reached global step {}. Stopping.".format(global_t))
                        coord.request_stop()
                        return

                    # Update the global networks
                    self.update(transitions, sess)
            except tf.errors.CancelledError:
                return

    def run_n_steps(self, n, sess):
        transitions = []
        if LSTM_POLICY:
            self.policy_net.reset_lstm_features()
        for _ in range(n):
            # Take a step
            if LSTM_POLICY:
                action_probs = self.policy_net.action_inference(self.state)
            else:
                action_probs = self.policy_net.action_prediction(self.state)
            # Sample an action from the policy's output distribution
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            next_state, reward, done, _ = self.env.step(action)

            # Store transition
            transitions.append(
                Transition(state=self.state, action=action, reward=reward,
                           next_state=next_state, done=done))

            # Increase local and global counters
            local_t = next(self.local_counter)
            global_t = next(self.global_counter)

            if local_t % 1000 == 0:
                tf.logging.info("{}: local step {}, global step {}".format(
                    self.name, local_t, global_t))

            if done:
                self.state = self.env.reset()
                # Reset the LSTM features at episode boundaries
                if LSTM_POLICY:
                    self.policy_net.reset_lstm_features()
                break
            else:
                self.state = next_state
        return transitions, local_t, global_t

    def update(self, transitions, sess):
        """
        Updates the global policy and value networks based on collected experience.

        Args:
            transitions: A list of experience transitions
            sess: A TensorFlow session
        """
        # If the episode was not done, bootstrap the value from the last state
        reward = 0.0
        if not transitions[-1].done:
            reward = self.value_net.predict_value(transitions[-1].next_state)

        if LSTM_POLICY:
            init_lstm_state = self.policy_net.get_init_features()

        # Accumulate minibatch examples
        states = []
        policy_targets = []
        value_targets = []
        actions = []

        for transition in transitions[::-1]:
            reward = transition.reward + self.discount_factor * reward
            # Advantage estimate: discounted return minus the value baseline
            policy_target = (reward - self.value_net.predict_value(transition.state))
            # Accumulate updates
            states.append(transition.state)
            actions.append(transition.action)
            policy_targets.append(policy_target)
            value_targets.append(reward)

        if LSTM_POLICY:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.policy_net.state_in[0]: np.array(init_lstm_state[0]),
                self.policy_net.state_in[1]: np.array(init_lstm_state[1]),
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }
        else:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }

        # Train the global estimators using local gradients
        global_step, pnet_loss, vnet_loss, _, _, pnet_summaries, vnet_summaries = sess.run(
            [
                self.global_step, self.policy_net.loss, self.value_net.loss,
                self.pnet_train_op, self.vnet_train_op,
                self.policy_net.summaries, self.value_net.summaries
            ], feed_dict)

        # Write summaries
        if self.summary_writer is not None and global_step % SUMMARY_EACH_STEPS == 0:
            self.summary_writer.add_summary(pnet_summaries, global_step)
            self.summary_writer.add_summary(vnet_summaries, global_step)
            self.summary_writer.flush()

        return pnet_loss, vnet_loss, pnet_summaries, vnet_summaries
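# ---------------------------------------------------------------------------
# Usage sketch: how Worker instances are typically wired together, one thread
# per worker, under a tf.train.Coordinator. This is an illustration only,
# assuming the shared "global" nets already exist in the graph; `make_env`,
# `num_workers`, and `t_max` are hypothetical names, not part of the original
# training script.

import threading  # assumed available; only needed for the sketches below


def _run_workers_example(policy_net, value_net, make_env,
                         num_workers=4, t_max=5):
    # All workers share a single global step counter
    global_counter = itertools.count()
    workers = [
        Worker(name="worker_{}".format(i),
               env=make_env(),
               policy_net=policy_net,
               value_net=value_net,
               global_counter=global_counter,
               discount_factor=0.99)
        for i in range(num_workers)
    ]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = []
        for w in workers:
            # Default-bind w so each thread runs its own worker
            t = threading.Thread(target=lambda w=w: w.run(sess, coord, t_max))
            t.start()
            threads.append(t)
        # Block until all workers stop (e.g. after max_global_steps)
        coord.join(threads)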
class PolicyMonitor(object):
    """
    Helps evaluate a policy by running an episode in an environment, saving a
    video, and plotting summaries to Tensorboard.

    Args:
        env: environment to run in
        policy_net: A policy estimator
        summary_writer: a tf.train.SummaryWriter used to write Tensorboard summaries
    """

    def __init__(self, env, policy_net, summary_writer, saver=None):
        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.env = env

        # Construct the checkpoint path relative to the summary logdir
        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))
        print('[PM] checkpoint_path: {}'.format(self.checkpoint_path))

        # Local policy net
        with tf.variable_scope("policy_eval"):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from the global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval", collection=tf.GraphKeys.TRAINABLE_VARIABLES))

    def eval_once(self, sess):
        with sess.as_default(), sess.graph.as_default():
            # Copy params to the local model
            global_step, _ = sess.run(
                [tf.contrib.framework.get_global_step(), self.copy_params_op])

            # Run an episode
            done = False
            state = self.env.reset()
            if LSTM_POLICY:
                self.policy_net.reset_lstm_features()
            total_reward = 0.0
            episode_length = 0
            while not done:
                if LSTM_POLICY:
                    action_probs = self.policy_net.action_inference(state)
                else:
                    action_probs = self.policy_net.action_prediction(state)
                action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
                next_state, reward, done, _ = self.env.step(action)
                # next_state = atari_helpers.atari_make_next_state(state, self.sp.process(next_state))
                total_reward += reward
                episode_length += 1
                state = next_state

            # Add summaries
            episode_summary = tf.Summary()
            episode_summary.value.add(simple_value=total_reward,
                                      tag="eval/total_reward")
            episode_summary.value.add(simple_value=episode_length,
                                      tag="eval/episode_length")
            self.summary_writer.add_summary(episode_summary, global_step)
            self.summary_writer.flush()

            if self.saver is not None:
                self.saver.save(sess, self.checkpoint_path)

            tf.logging.info(
                "Eval results at step {}: total_reward {}, episode_length {}".format(
                    global_step, total_reward, episode_length))

            return total_reward, episode_length

    def continuous_eval(self, eval_every, sess, coord):
        """
        Continuously evaluates the policy every [eval_every] seconds.
        """
        try:
            while not coord.should_stop():
                self.eval_once(sess)
                # Sleep until the next evaluation cycle
                time.sleep(eval_every)
        except tf.errors.CancelledError:
            return
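# ---------------------------------------------------------------------------
# Usage sketch: running the evaluator alongside the workers in a separate
# daemon thread. This is an illustration only; `eval_env` and the 60-second
# interval are hypothetical, and `threading` is assumed to be imported as in
# the worker sketch above.


def _run_monitor_example(sess, coord, eval_env, policy_net, summary_writer):
    monitor = PolicyMonitor(env=eval_env,
                            policy_net=policy_net,
                            summary_writer=summary_writer,
                            saver=tf.train.Saver())
    # A daemon thread so evaluation never blocks process shutdown
    monitor_thread = threading.Thread(
        target=lambda: monitor.continuous_eval(60, sess, coord))
    monitor_thread.daemon = True
    monitor_thread.start()
    return monitor_thread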