    def run_n_steps(self, n, sess):
        transitions = []
        for _ in range(n):
            # Take a step
            action_probs = self._policy_net_predict(self.state, sess)
            action = np.random.choice(np.arange(len(action_probs)),
                                      p=action_probs)
            next_state, reward, done, _ = self.env.step(action)
            next_state = atari_helpers.atari_make_next_state(
                self.state, self.sp.process(next_state))

            # Store transition
            transitions.append(
                Transition(state=self.state,
                           action=action,
                           reward=reward,
                           next_state=next_state,
                           done=done))

            # Increase local and global counters
            local_t = next(self.local_counter)
            global_t = next(self.global_counter)

            if local_t % 100 == 0:
                tf.logging.info("{}: local Step {}, global step {}".format(
                    self.name, local_t, global_t))

            if done:
                self.state = atari_helpers.atari_make_initial_state(
                    self.sp.process(self.env.reset()))
                break
            else:
                self.state = next_state
        return transitions, local_t, global_t
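
The Transition records collected above are plain namedtuples. A definition consistent with the five fields used throughout these listings would be the following sketch (an assumption; the actual declaration lives elsewhere in worker.py):

import collections

# Assumed declaration, matching the fields used by run_n_steps above.
Transition = collections.namedtuple(
    "Transition", ["state", "action", "reward", "next_state", "done"])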
  def eval_once(self, sess):
    with sess.as_default(), sess.graph.as_default():
      # Copy params to local model
      global_step, _ = sess.run([tf.contrib.framework.get_global_step(), self.copy_params_op])

      # Run an episode
      done = False
      state = atari_helpers.atari_make_initial_state(self.sp.process(self.env.reset()))
      total_reward = 0.0
      episode_length = 0
      while not done:
        action_probs = self._policy_net_predict(state, sess)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        next_state, reward, done, _ = self.env.step(action)
        next_state = atari_helpers.atari_make_next_state(state, self.sp.process(next_state))
        total_reward += reward
        episode_length += 1
        state = next_state

      # Add summaries
      episode_summary = tf.Summary()
      episode_summary.value.add(simple_value=total_reward, tag="eval/total_reward")
      episode_summary.value.add(simple_value=episode_length, tag="eval/episode_length")
      self.summary_writer.add_summary(episode_summary, global_step)
      self.summary_writer.flush()

      if self.saver is not None:
        self.saver.save(sess, self.checkpoint_path)

      tf.logging.info("Eval results at step {}: total_reward {}, episode_length {}".format(global_step, total_reward, episode_length))

      f_reward.write(str(global_step) + " " + str(total_reward) + " " + str(episode_length) + "\n")

      return total_reward, episode_length
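
eval_once evaluates a single episode and writes the result as TensorBoard summaries. In the usual setup it is driven by a small monitor loop on its own thread; a minimal sketch (the eval_every interval and the monitor object are assumptions, not names from the snippet above) could look like this:

import time

import tensorflow as tf

def continuous_eval(monitor, sess, coord, eval_every=60):
    """Call monitor.eval_once every eval_every seconds until the coordinator stops."""
    try:
        while not coord.should_stop():
            monitor.eval_once(sess)
            time.sleep(eval_every)
    except tf.errors.CancelledError:
        return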
    def run(self, sess, coord, t_max):
        with sess.as_default(), sess.graph.as_default():
            # Initial state
            self.state = atari_helpers.atari_make_initial_state(
                self.sp.process(self.env.reset()))
            try:
                while not coord.should_stop():
                    # Copy Parameters from the global networks
                    sess.run(self.copy_params_op)

                    # Collect some experience
                    transitions, local_t, global_t = self.run_n_steps(
                        t_max, sess)

                    if self.max_global_steps is not None and global_t >= self.max_global_steps:
                        tf.logging.info(
                            "Reached global step {}. Stopping.".format(
                                global_t))
                        coord.request_stop()
                        return

                    # Update the global networks
                    self.update(transitions, sess)

            except tf.errors.CancelledError:
                return
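
run() is the per-worker training loop; it expects to be started on its own thread and stopped through a shared tf.train.Coordinator. A rough sketch of the surrounding launch code (the workers list and the t_max value are assumptions about the training script, not part of the listing above) might be:

import threading

import tensorflow as tf

def launch_workers(workers, t_max):
    """Start one thread per worker and block until the coordinator stops them all."""
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()

        threads = []
        for worker in workers:
            t = threading.Thread(target=lambda w=worker: w.run(sess, coord, t_max))
            t.start()
            threads.append(t)

        coord.join(threads)  # block until every worker stops or requests a stop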
Example #5
  def run_n_steps(self, n, sess):
    transitions = []
    for _ in range(n):
      # Take a step
      action_probs = self._policy_net_predict(self.state, sess)
      action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
      next_state, reward, done, _ = self.env.step(action)
      next_state = atari_helpers.atari_make_next_state(self.state, self.sp.process(next_state))

      # Store transition
      transitions.append(Transition(
        state=self.state, action=action, reward=reward, next_state=next_state, done=done))

      # Increase local and global counters
      local_t = next(self.local_counter)
      global_t = next(self.global_counter)

      if local_t % 100 == 0:
        tf.logging.info("{}: local Step {}, global step {}".format(self.name, local_t, global_t))

      if done:
        self.state = atari_helpers.atari_make_initial_state(self.sp.process(self.env.reset()))
        break
      else:
        self.state = next_state
    return transitions, local_t, global_t
    def testPredict(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }

            loss = sess.run(estimator.loss, feed_dict)
            pred = sess.run(estimator.predictions, feed_dict)

            # Assertions
            self.assertTrue(loss != 0.0)
            self.assertEqual(pred["probs"].shape, (1, len(VALID_ACTIONS)))
            self.assertEqual(pred["logits"].shape, (1, len(VALID_ACTIONS)))
    def testGradient(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))
        grads = [g for g, _ in estimator.grads_and_vars]

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds to get gradients
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }

            grads_ = sess.run(grads, feed_dict)

            # Apply calculated gradients
            grad_feed_dict = {k: v for k, v in zip(grads, grads_)}
            _ = sess.run(estimator.train_op, grad_feed_dict)
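
Note that testGradient evaluates the gradient tensors first and then feeds the resulting arrays back into those same tensors when running train_op, so the optimizer applies the precomputed values instead of re-deriving them from the loss (a feed_dict may override graph tensors this way).

Both tests rely only on a handful of estimator attributes (states, targets, actions, loss, predictions["probs"/"logits"], grads_and_vars, train_op). A bare-bones stand-in consistent with those attributes, assuming 84x84x4 stacked Atari frames, is sketched below; the real PolicyEstimator is presumably a deeper convolutional network, so treat this only as an illustration of the interface:

import tensorflow as tf

class MinimalPolicyEstimator(object):
    """Sketch only: linear policy over flattened frames, policy-gradient loss."""
    def __init__(self, num_outputs):
        self.states = tf.placeholder(tf.float32, [None, 84, 84, 4], name="states")
        self.targets = tf.placeholder(tf.float32, [None], name="targets")   # advantages
        self.actions = tf.placeholder(tf.int32, [None], name="actions")     # chosen actions

        flat = tf.reshape(self.states, [-1, 84 * 84 * 4])
        logits = tf.layers.dense(flat, num_outputs)
        probs = tf.nn.softmax(logits) + 1e-8
        self.predictions = {"logits": logits, "probs": probs}

        # Probability assigned to the action that was actually taken in each state
        batch_size = tf.shape(self.states)[0]
        gather_indices = tf.range(batch_size) * tf.shape(probs)[1] + self.actions
        picked_action_probs = tf.gather(tf.reshape(probs, [-1]), gather_indices)

        # Policy-gradient loss weighted by the advantage targets
        self.loss = -tf.reduce_sum(tf.log(picked_action_probs) * self.targets)

        optimizer = tf.train.RMSPropOptimizer(0.00025)
        self.grads_and_vars = [(g, v) for g, v in optimizer.compute_gradients(self.loss)
                               if g is not None]
        self.train_op = optimizer.apply_gradients(self.grads_and_vars)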
Example #8
File: worker.py  Project: 404akhan/a3c
  def run_n_steps(self, n, sess):
    transitions = []
    for _ in range(n):
      # Take a step
      action_probs = self._policy_net_predict(self.state, sess)
      action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
      next_state, reward, done, _ = self.env.step(action)
      next_state = atari_helpers.atari_make_next_state(self.state, self.sp.process(next_state))
      self.total_reward += reward
      self.episode_length += 1

      # Store transition
      transitions.append(Transition(
        state=self.state, action=action, reward=reward, next_state=next_state, done=done))

      # Increase local and global counters
      local_t = next(self.local_counter)
      global_t = next(self.global_counter)

      if local_t % 100 == 0:
        tf.logging.info("{}: local Step {}, global step {}".format(self.name, local_t, global_t))

      if done:
        self.state = atari_helpers.atari_make_initial_state(self.sp.process(self.env.reset()))
        f = open('logs_policy.out', 'a')
        f.write("agent {}, local {}, global {}, total_reward {}, episode_length {}\n".format(
          self.name, self.local_counter, self.global_counter, self.total_reward, self.episode_length))
        f.close()
        self.total_reward = 0
        self.episode_length = 0
        break
      else:
        self.state = next_state
    return transitions, local_t, global_t
Example #9
    def run_n_steps(self, n, sess):
        transitions = []
        for _ in range(n):
            # Take a step
            action_probs = self._policy_net_predict(self.state, sess)
            action = np.random.choice(np.arange(len(action_probs)),
                                      p=action_probs)
            repetition_probs = self._repetition_net_predict(self.state, sess)
            repetition = np.random.choice(np.arange(len(repetition_probs)),
                                          p=repetition_probs)

            rewards_collected = []

            # print("repetition", self.name,repetition)
            for rep in range(repetition + 1):
                next_state, reward, done, _ = self.env.step(action)
                # print(self.name,rep)
                # print("action",action)
                next_state = atari_helpers.atari_make_next_state(
                    self.state, self.sp.process(next_state))
                rewards_collected.append(reward)

                # Increase local and global counters
                local_t = next(self.local_counter)
                global_t = next(self.global_counter)

                if local_t % 100 == 0:
                    tf.logging.info("{}: local Step {}, global step {}".format(
                        self.name, local_t, global_t))

                if done:
                    transitions.append(
                        Transition(state=self.state,
                                   action=action,
                                   repetition=repetition,
                                   reward=sum(rewards_collected) /
                                   len(rewards_collected),
                                   next_state=next_state,
                                   done=done))

                    self.state = atari_helpers.atari_make_initial_state(
                        self.sp.process(self.env.reset()))
                    break
                else:
                    if rep == repetition:
                        transitions.append(
                            Transition(state=self.state,
                                       action=action,
                                       repetition=repetition,
                                       reward=sum(rewards_collected) /
                                       len(rewards_collected),
                                       next_state=next_state,
                                       done=done))

                    self.state = next_state

        return transitions, local_t, global_t
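
Because this variant also samples a repetition count and averages the rewards collected over the repeated step, its Transition tuple carries an extra field. A namedtuple definition consistent with the call above would be the following (an assumption; the project declares it elsewhere):

import collections

# Assumed declaration for the action-repetition variant.
Transition = collections.namedtuple(
    "Transition", ["state", "action", "repetition", "reward", "next_state", "done"])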
Example #10
    def testValueNetPredict(self):
        w = Worker(name="test",
                   env=make_env(),
                   policy_net=self.global_policy_net,
                   value_net=self.global_value_net,
                   global_counter=self.global_counter,
                   discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            state = self.sp.process(self.env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            state_value = w._value_net_predict(processed_state, sess)
            self.assertEqual(state_value.shape, ())
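
The value-net test only checks that _value_net_predict returns a scalar for a single processed state. A bare-bones ValueEstimator consistent with that (and with the predictions["logits"] access used by the workers), assuming the same 84x84x4 input, could look like the sketch below; the real estimator is presumably convolutional:

import tensorflow as tf

class MinimalValueEstimator(object):
    """Sketch only: linear state-value critic over flattened frames."""
    def __init__(self):
        self.states = tf.placeholder(tf.float32, [None, 84, 84, 4], name="states")
        self.targets = tf.placeholder(tf.float32, [None], name="targets")  # n-step returns

        flat = tf.reshape(self.states, [-1, 84 * 84 * 4])
        # One value per state; squeeze so predictions["logits"][0] is a scalar.
        logits = tf.squeeze(tf.layers.dense(flat, 1), axis=1)
        self.predictions = {"logits": logits}
        self.loss = tf.reduce_sum(tf.squared_difference(logits, self.targets))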
Example #12
  def run(self, sess, coord, t_max):
    with sess.as_default(), sess.graph.as_default():
      # Initial state
      self.state = atari_helpers.atari_make_initial_state(self.sp.process(self.env.reset()))
      try:
        while not coord.should_stop():
          # Copy Parameters from the global networks
          sess.run(self.copy_params_op)

          # Collect some experience
          transitions, local_t, global_t = self.run_n_steps(t_max, sess)

          if self.max_global_steps is not None and global_t >= self.max_global_steps:
            tf.logging.info("Reached global step {}. Stopping.".format(global_t))
            coord.request_stop()
            return

          # Update the global networks
          self.update(transitions, sess)

      except tf.errors.CancelledError:
        return
Example #13
    def testRunNStepsAndUpdate(self):
        w = Worker(name="test",
                   env=make_env(),
                   policy_net=self.global_policy_net,
                   value_net=self.global_value_net,
                   global_counter=self.global_counter,
                   discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            state = self.sp.process(self.env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            w.state = processed_state
            transitions, local_t, global_t = w.run_n_steps(10, sess)
            policy_net_loss, value_net_loss, policy_net_summaries, value_net_summaries = w.update(
                transitions, sess)

        self.assertEqual(len(transitions), 10)
        self.assertIsNotNone(policy_net_loss)
        self.assertIsNotNone(value_net_loss)
        self.assertIsNotNone(policy_net_summaries)
        self.assertIsNotNone(value_net_summaries)
Example #15
class Worker(object):

    def __init__(self, name, env, policy_net, value_net, global_counter,
                 discount_factor=0.99, summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.sp = StateProcessor()
        self.summary_writer = summary_writer
        self.env = env

        with tf.variable_scope(name):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)
            self.value_net = ValueEstimator(reuse=True)

        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name+'/',
                collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        )

        self.vnet_train_op = make_train_op(
            self.value_net, self.global_value_net)
        self.pnet_train_op = make_train_op(
            self.policy_net, self.global_policy_net)

        self.state = None

    def run(self, sess, coord, t_max):
        with sess.as_default(), sess.graph.as_default():
            self.state = atari_helpers.atari_make_initial_state(
                self.sp.process(self.env.reset()))
            try:
                while not coord.should_stop():
                    sess.run(self.copy_params_op)

                    transitions, local_t, global_t = self.run_n_steps(
                        t_max, sess)

                    if self.max_global_steps is not None and global_t >= self.max_global_steps:
                        tf.logging.info(
                            "Reached global step {}. Stopping."
                            .format(global_t))
                        coord.request_stop()
                        return

                    self.update(transitions, sess)

            except tf.errors.CancelledError:
                return

    def _policy_net_predict(self, state, sess):
        feed_dict = {self.policy_net.states: [state]}
        preds = sess.run(self.policy_net.predictions, feed_dict)
        return preds["probs"][0]

    def _value_net_predict(self, state, sess):
        feed_dict = {self.value_net.states: [state]}
        preds = sess.run(self.value_net.predictions, feed_dict)
        return preds["logits"][0]

    def run_n_steps(self, n, sess):
        transitions = []

        for _ in range(n):
            action_probs = self._policy_net_predict(self.state, sess)
            action = np.random.choice(
                np.arange(len(action_probs)), p=action_probs)
            next_state, reward, done, _ = self.env.step(action)
            next_state = atari_helpers.atari_make_next_state(
                self.state, self.sp.process(next_state))

            transitions.append(Transition(state=self.state,
                                          action=action,
                                          reward=reward,
                                          next_state=next_state,
                                          done=done))

            local_t = next(self.local_counter)
            global_t = next(self.global_counter)

            if local_t % 100 == 0:
                tf.logging.info("{}: local step {}, global step {}". format(
                    self.name, local_t, global_t))

            if done:
                self.state = atari_helpers.atari_make_initial_state(
                    self.sp.process(self.env.reset()))
                break
            else:
                self.state = next_state

        return transitions, local_t, global_t
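
The run() and testRunNStepsAndUpdate listings call self.update(transitions, sess), which is cut off in this last example. A sketch of what that update typically does in this A3C setup, bootstrapping the n-step return from the local value net and then running the two training ops, is given below; the exact estimator attribute names (targets, actions, loss) are assumptions carried over from the tests above, and numpy is assumed to be imported as np.

    def update(self, transitions, sess):
        """Sketch only: n-step targets/advantages, then one train step on both nets."""
        # Bootstrap from the value of the last next_state unless the episode ended there.
        reward = 0.0
        if not transitions[-1].done:
            reward = self._value_net_predict(transitions[-1].next_state, sess)

        states, actions, policy_targets, value_targets = [], [], [], []
        for transition in transitions[::-1]:
            reward = transition.reward + self.discount_factor * reward
            advantage = reward - self._value_net_predict(transition.state, sess)
            states.append(transition.state)
            actions.append(transition.action)
            policy_targets.append(advantage)
            value_targets.append(reward)

        feed_dict = {
            self.policy_net.states: np.array(states),
            self.policy_net.targets: policy_targets,
            self.policy_net.actions: actions,
            self.value_net.states: np.array(states),
            self.value_net.targets: value_targets,
        }

        # Apply the locally computed update to the global networks and return the losses.
        pnet_loss, vnet_loss, _, _ = sess.run([
            self.policy_net.loss,
            self.value_net.loss,
            self.pnet_train_op,
            self.vnet_train_op,
        ], feed_dict)

        return pnet_loss, vnet_loss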