Example #1
class Worker(object):
    """
    An A3C worker thread. Runs episodes locally and updates global shared value and policy nets.

    Args:
        name: A unique name for this worker
        env: The Gym environment used by this worker
        policy_net: Instance of the globally shared policy net
        value_net: Instance of the globally shared value net
        global_counter: Iterator that holds the global step
        discount_factor: Reward discount factor
        summary_writer: A tf.train.SummaryWriter for Tensorboard summaries
        max_global_steps: If set, stop coordinator when global_counter > max_global_steps
    """
    def __init__(self,
                 name,
                 env,
                 policy_net,
                 value_net,
                 global_counter,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.summary_writer = summary_writer
        self.env = env

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)
            self.value_net = ValueEstimator(reuse=True)

        # Op to copy params from the global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op = make_train_op(self.value_net,
                                           self.global_value_net)
        self.pnet_train_op = make_train_op(self.policy_net,
                                           self.global_policy_net)

        self.state = None

    def run(self, sess, coord, t_max):
        with sess.as_default(), sess.graph.as_default():
            # Initial state
            self.state = self.env.reset()
            try:
                while not coord.should_stop():
                    # Copy Parameters from the global networks
                    sess.run(self.copy_params_op)

                    # Collect some experience
                    transitions, local_t, global_t = self.run_n_steps(
                        t_max, sess)

                    if self.max_global_steps is not None and global_t >= self.max_global_steps:
                        tf.logging.info(
                            "Reached global step {}. Stopping.".format(
                                global_t))
                        coord.request_stop()
                        return

                    # Update the global networks
                    self.update(transitions, sess)

            except tf.errors.CancelledError:
                return

    def run_n_steps(self, n, sess):
        transitions = []
        if LSTM_POLICY:
            self.policy_net.reset_lstm_features()
        for _ in range(n):
            # Take a step
            if LSTM_POLICY:
                action_probs = self.policy_net.action_inference(self.state)
            else:
                action_probs = self.policy_net.action_prediction(self.state)

            # Sample an action from the policy's output distribution
            action = np.random.choice(np.arange(len(action_probs)),
                                      p=action_probs)
            next_state, reward, done, _ = self.env.step(action)

            # Store transition
            transitions.append(
                Transition(state=self.state,
                           action=action,
                           reward=reward,
                           next_state=next_state,
                           done=done))

            # Increase local and global counters
            local_t = next(self.local_counter)
            global_t = next(self.global_counter)

            if local_t % 1000 == 0:
                tf.logging.info("{}: local Step {}, global step {}".format(
                    self.name, local_t, global_t))

            if done:
                self.state = self.env.reset()
                # Reset the LSTM features at episode boundaries
                if LSTM_POLICY:
                    self.policy_net.reset_lstm_features()
                break
            else:
                self.state = next_state
        return transitions, local_t, global_t

    def update(self, transitions, sess):
        """
        Updates global policy and value networks based on collected experience

        Args:
            transitions: A list of experience transitions
            sess: A Tensorflow session
        """

        # If the episode was not done, bootstrap the return with the value of the last state
        reward = 0.0
        if not transitions[-1].done:
            reward = self.value_net.predict_value(transitions[-1].next_state)

        if LSTM_POLICY:
            init_lstm_state = self.policy_net.get_init_features()

        # Accumulate minibatch examples
        states = []
        policy_targets = []
        value_targets = []
        actions = []
        features = []

        for transition in transitions[::-1]:
            reward = transition.reward + self.discount_factor * reward
            policy_target = (reward -
                             self.value_net.predict_value(transition.state))
            # Accumulate updates
            states.append(transition.state)
            actions.append(transition.action)
            policy_targets.append(policy_target)
            value_targets.append(reward)

        if LSTM_POLICY:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.policy_net.state_in[0]: np.array(init_lstm_state[0]),
                self.policy_net.state_in[1]: np.array(init_lstm_state[1]),
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }
        else:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }

        # Train the global estimators using local gradients
        global_step, pnet_loss, vnet_loss, _, _, pnet_summaries, vnet_summaries = sess.run(
            [
                self.global_step, self.policy_net.loss, self.value_net.loss,
                self.pnet_train_op, self.vnet_train_op,
                self.policy_net.summaries, self.value_net.summaries
            ], feed_dict)

        # Write summaries
        if self.summary_writer is not None and global_step % SUMMARY_EACH_STEPS == 0:
            self.summary_writer.add_summary(pnet_summaries, global_step)
            self.summary_writer.add_summary(vnet_summaries, global_step)
            self.summary_writer.flush()

        return pnet_loss, vnet_loss, pnet_summaries, vnet_summaries
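
A minimal usage sketch for the Worker above, assuming the usual A3C setup around it: the shared nets live under a "global" variable scope, a global_step variable exists in the graph, and PolicyEstimator / ValueEstimator follow the constructor signatures used above. NUM_WORKERS, T_MAX, NUM_ACTIONS, and the environment name are illustrative, not taken from the original.

import itertools
import threading

import gym
import tensorflow as tf

NUM_WORKERS = 4
T_MAX = 20        # rollout length per update
NUM_ACTIONS = 4   # size of the environment's action space

with tf.device("/cpu:0"):
    # Counts parameter updates; Worker reads it via tf.contrib.framework.get_global_step()
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Globally shared policy and value nets
    with tf.variable_scope("global"):
        global_policy_net = PolicyEstimator(NUM_ACTIONS)
        global_value_net = ValueEstimator(reuse=True)

    global_counter = itertools.count()
    workers = [
        Worker(name="worker_{}".format(i),
               env=gym.make("Breakout-v0"),
               policy_net=global_policy_net,
               value_net=global_value_net,
               global_counter=global_counter,
               discount_factor=0.99)
        for i in range(NUM_WORKERS)
    ]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    worker_threads = []
    for worker in workers:
        # One thread per worker; each thread runs rollouts and pushes updates
        t = threading.Thread(target=lambda w=worker: w.run(sess, coord, T_MAX))
        t.start()
        worker_threads.append(t)
    coord.join(worker_threads)
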
Example #2
class PolicyMonitor(object):
    """
    Helps evaluate a policy by running an episode in an environment,
    saving a video, and plotting summaries to Tensorboard.

    Args:
        env: environment to run in
        policy_net: A policy estimator
        summary_writer: a tf.train.SummaryWriter used to write Tensorboard summaries
    """
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.env = env

        # Checkpoint path, resolved relative to the summary writer's log directory
        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))
        print('[PM] checkpoint_path: {}'.format(self.checkpoint_path))

        # Local policy net
        with tf.variable_scope("policy_eval"):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from the global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))

    def eval_once(self, sess):
        with sess.as_default(), sess.graph.as_default():
            # Copy params to local model
            global_step, _ = sess.run(
                [tf.contrib.framework.get_global_step(), self.copy_params_op])

            # Run an episode
            done = False
            state = self.env.reset()
            if LSTM_POLICY:
                self.policy_net.reset_lstm_features()
            total_reward = 0.0
            episode_length = 0
            while not done:
                if LSTM_POLICY:
                    action_probs = self.policy_net.action_inference(state)
                else:
                    action_probs = self.policy_net.action_prediction(state)
                action = np.random.choice(np.arange(len(action_probs)),
                                          p=action_probs)
                next_state, reward, done, _ = self.env.step(action)
                # next_state = atari_helpers.atari_make_next_state(state, self.sp.process(next_state))
                total_reward += reward
                episode_length += 1
                state = next_state

            # Add summaries
            episode_summary = tf.Summary()
            episode_summary.value.add(simple_value=total_reward,
                                      tag="eval/total_reward")
            episode_summary.value.add(simple_value=episode_length,
                                      tag="eval/episode_length")
            self.summary_writer.add_summary(episode_summary, global_step)
            self.summary_writer.flush()

            if self.saver is not None:
                self.saver.save(sess, self.checkpoint_path)

            tf.logging.info(
                "Eval results at step {}: total_reward {}, episode_length {}".
                format(global_step, total_reward, episode_length))

            return total_reward, episode_length

    def continuous_eval(self, eval_every, sess, coord):
        """
        Continuously evaluates the policy every [eval_every] seconds.
        """
        try:
            while not coord.should_stop():
                self.eval_once(sess)
                # Sleep until next evaluation cycle
                time.sleep(eval_every)
        except tf.errors.CancelledError:
            return
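
A similarly hedged sketch of running PolicyMonitor in a background evaluation thread next to the workers from Example #1. The log directory, evaluation interval, and environment are illustrative; global_policy_net and the "global" variable scope are assumed to already exist in the training graph, matching the setup sketched above.

import os
import threading

import gym
import tensorflow as tf

EVAL_EVERY = 300  # seconds between evaluations

# __init__ resolves the checkpoint path one level above the summary log dir,
# so /tmp/a3c/eval gives /tmp/a3c/checkpoints/model; the dir must exist for saver.save()
if not os.path.exists("/tmp/a3c/checkpoints"):
    os.makedirs("/tmp/a3c/checkpoints")

summary_writer = tf.summary.FileWriter("/tmp/a3c/eval")
saver = tf.train.Saver(max_to_keep=5)

policy_monitor = PolicyMonitor(env=gym.make("Breakout-v0"),
                               policy_net=global_policy_net,
                               summary_writer=summary_writer,
                               saver=saver)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    # Evaluate the current global policy every EVAL_EVERY seconds until the
    # coordinator requests a stop
    monitor_thread = threading.Thread(
        target=lambda: policy_monitor.continuous_eval(EVAL_EVERY, sess, coord))
    monitor_thread.start()
    coord.join([monitor_thread])

In a full training script this thread would be joined together with the worker threads, so the monitor stops when the coordinator does.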