Code Example #1
 def test_episodic_observer_overflow_episode_bypass(self):
     env1 = _env_creator(3)()
     env2 = _env_creator(4)()
     with _create_add_episode_observer_fn(table_name='test_table',
                                          max_sequence_length=4,
                                          priority=1,
                                          bypass_partial_episodes=True)(
                                              self._client) as observer:
         policy = _create_random_policy_from_env(env1)
         # env1 -> writes only ONE episode. Note that `max_sequence_length`
         # must be one more than the episode length, since, as in TF-Agents,
         # we append a trajectory as the `LAST` step.
         driver = py_driver.PyDriver(env1,
                                     policy,
                                     observers=[observer],
                                     max_steps=6)
         driver.run(env1.reset())
         # env2 -> writes NO episodes (all of them have length >
         # `max_sequence_length`).
         policy = _create_random_policy_from_env(env2)
         driver = py_driver.PyDriver(env2,
                                     policy,
                                     observers=[observer],
                                     max_steps=6)
         driver.run(env2.reset())
     self.assertEqual(1, self._writer.create_item.call_count)
Code Example #2
File: metric_utils.py  Project: yrbahn/agents
def compute(metrics, environment, policy, num_episodes=1):
    """Compute metrics using `policy` on the `environment`.

    Args:
      metrics: List of metrics to compute.
      environment: py_environment instance.
      policy: py_policy instance used to step the environment. A tf_policy can
        be used in eager mode.
      num_episodes: Number of episodes to compute the metrics over.

    Returns:
      A dictionary of results {metric_name: metric_value}.
    """
    for metric in metrics:
        metric.reset()

    time_step = environment.reset()
    policy_state = policy.get_initial_state(environment.batch_size)

    driver = py_driver.PyDriver(environment,
                                policy,
                                observers=metrics,
                                max_steps=None,
                                max_episodes=num_episodes)
    driver.run(time_step, policy_state)

    results = [(metric.name, metric.result()) for metric in metrics]
    return collections.OrderedDict(results)
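The `compute` helper above can be exercised end to end with the standard py metrics. The snippet below is a minimal usage sketch; the `suite_gym` loader, the 'CartPole-v0' environment name, and the random policy are illustrative assumptions borrowed from the other examples on this page, not part of the original snippet.

from tf_agents.environments import suite_gym
from tf_agents.metrics import py_metrics
from tf_agents.policies import random_py_policy

# Hypothetical usage of compute(); the environment and policy choices are assumptions.
env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(env.time_step_spec(), env.action_spec())
metrics = [py_metrics.AverageReturnMetric(),
           py_metrics.AverageEpisodeLengthMetric()]
results = compute(metrics, env, policy, num_episodes=5)
# results is an OrderedDict such as
# {'AverageReturn': ..., 'AverageEpisodeLength': ...}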
Code Example #3
    def test_with_py_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        trajectory_spec = trajectory.from_transition(env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     env.time_step_spec())
        trajectory_spec = tensor_spec.from_spec(trajectory_spec)

        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec, py_mode=True)

        driver = py_driver.PyDriver(env,
                                    policy, [tfrecord_observer],
                                    max_steps=10)
        time_step = env.reset()
        driver.run(time_step)
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)

        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
Code Example #4
    def testRunOnce(self, max_steps, max_episodes, expected_steps):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer_observer = MockReplayBufferObserver()
        transition_replay_buffer_observer = MockReplayBufferObserver()
        driver = py_driver.PyDriver(
            env,
            policy,
            observers=[replay_buffer_observer],
            transition_observers=[transition_replay_buffer_observer],
            max_steps=max_steps,
            max_episodes=max_episodes,
        )

        initial_time_step = env.reset()
        initial_policy_state = policy.get_initial_state()
        driver.run(initial_time_step, initial_policy_state)
        trajectories = replay_buffer_observer.gather_all()
        self.assertEqual(trajectories, self._trajectories[:expected_steps])

        transitions = transition_replay_buffer_observer.gather_all()
        self.assertLen(transitions, expected_steps)
        # TimeStep, Action, NextTimeStep
        self.assertLen(transitions[0], 3)
Code Example #5
  def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):

    expected_trajectories = [
        trajectory.Trajectory(
            step_type=np.array([0, 0]),
            observation=np.array([0, 0]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([1, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([1., 1.])),
        trajectory.Trajectory(
            step_type=np.array([1, 1]),
            observation=np.array([2, 1]),
            action=np.array([1, 2]),
            policy_info=np.array([2, 4]),
            next_step_type=np.array([2, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([0., 1.])),
        trajectory.Trajectory(
            step_type=np.array([2, 1]),
            observation=np.array([3, 3]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([0, 2]),
            reward=np.array([0., 1.]),
            discount=np.array([1., 0.]))
    ]

    env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
    env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
    env = batched_py_environment.BatchedPyEnvironment([env1, env2])

    policy = driver_test_utils.PyPolicyMock(
        env.time_step_spec(),
        env.action_spec(),
        initial_policy_state=np.array([1, 2]))
    replay_buffer_observer = MockReplayBufferObserver()

    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
    initial_time_step = env.reset()
    initial_policy_state = policy.get_initial_state()
    driver.run(initial_time_step, initial_policy_state)
    trajectories = replay_buffer_observer.gather_all()

    self.assertEqual(
        len(trajectories), len(expected_trajectories[:expected_length]))

    for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
      for t1_field, t2_field in zip(t1, t2):
        self.assertAllEqual(t1_field, t2_field)
Code Example #6
def collect_episode(environment, policy, num_episodes):

    driver = py_driver.PyDriver(environment,
                                py_tf_eager_policy.PyTFEagerPolicy(
                                    policy, use_tf_function=True),
                                [rb_observer],
                                max_episodes=num_episodes)
    initial_time_step = environment.reset()
    driver.run(initial_time_step)
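Note that `rb_observer` is not defined inside `collect_episode`; in the originating tutorial it is a module-level replay-buffer observer. One plausible definition, sketched under the assumption that a Reverb client and table already exist (as in example #17), is:

from tf_agents.replay_buffers import reverb_utils

# Assumed module-level observer; `py_client` and 'uniform_table' are
# placeholders for an existing Reverb client and table name.
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
    py_client, table_name='uniform_table', sequence_length=2)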
Code Example #7
 def test_episodic_observer_overflow_episode_raise_value_error(self):
   env = _env_creator(3)()
   with _create_add_episode_observer_fn(
       table_name='test_table', max_sequence_length=2,
       priority=1)(self._client) as observer:
     policy = _create_random_policy_from_env(env)
     driver = py_driver.PyDriver(
         env, policy, observers=[observer], max_steps=4)
     with self.assertRaises(ValueError):
       driver.run(env.reset())
Code Example #8
def collect_episode(environment, policy, num_episodes, replay_buffer_observer):
    """Collect game episode trajectories."""
    initial_time_step = environment.reset()

    driver = py_driver.PyDriver(environment,
                                py_tf_eager_policy.PyTFEagerPolicy(
                                    policy, use_tf_function=True),
                                [replay_buffer_observer],
                                max_episodes=num_episodes)
    driver.run(initial_time_step)
Code Example #9
  def test_observer_writes(self, create_observer_fn, env_fn, expected_items,
                           writer_call_counts, max_steps, append_count):
    env = env_fn()
    with create_observer_fn(self._client) as observer:
      policy = _create_random_policy_from_env(env)
      driver = py_driver.PyDriver(
          env, policy, observers=[observer], max_steps=max_steps)
      driver.run(env.reset())

    self.assertEqual(writer_call_counts, self._writer.call_count)
    self.assertEqual(append_count, self._writer.append.call_count)
    self.assertEqual(expected_items, self._writer.create_item.call_count)
Code Example #10
 def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
     env = driver_test_utils.PyEnvironmentMock()
     policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                             env.action_spec())
     replay_buffer_observer = MockReplayBufferObserver()
     with self.assertRaises(ValueError):
         py_driver.PyDriver(
             env,
             policy,
             observers=[replay_buffer_observer],
             max_steps=max_steps,
             max_episodes=max_episodes,
         )
Code Example #11
 def test_episodic_observer_num_steps(self):
     create_observer_fn = _create_add_episode_observer_fn(
         table_name='test_table', max_sequence_length=8, priority=3)
     env = _env_creator(3)()
     with create_observer_fn(self._client) as observer:
         policy = _create_random_policy_from_env(env)
         driver = py_driver.PyDriver(env,
                                     policy,
                                     observers=[observer],
                                     max_steps=10)
         driver.run(env.reset())
         # After each episode, we reset `cached_steps`.
         # We run the driver for 3 full episodes and one step.
         self.assertEqual(observer._cached_steps, 1)
Code Example #12
  def test_observer_writes_multi_tables(self):
    episode_length = 3
    collect_step_count = 6
    table_count = 2
    create_observer_fn = _create_add_sequence_observer_fn(
        table_name=['test_table1', 'test_table2'],
        sequence_length=episode_length,
        stride_length=episode_length)
    env = _env_creator(episode_length)()
    with create_observer_fn(self._client) as observer:
      policy = _create_random_policy_from_env(env)
      driver = py_driver.PyDriver(
          env, policy, observers=[observer], max_steps=collect_step_count)
      driver.run(env.reset())

    self.assertEqual(table_count * int(collect_step_count / episode_length),
                     self._writer.create_item.call_count)
Code Example #13
def run_env(env, policy, max_episodes, max_steps=None):
    logging.info('Running policy on env ..')
    replay_buffer = []
    metrics = [
        py_metrics.AverageReturnMetric(),
        py_metrics.AverageEpisodeLengthMetric()
    ]
    observers = [replay_buffer.append]
    observers.extend(metrics)
    driver = py_driver.PyDriver(env,
                                policy,
                                observers,
                                max_steps=max_steps,
                                max_episodes=max_episodes)
    initial_time_step = env.reset()
    initial_state = policy.get_initial_state(1)
    driver.run(initial_time_step, initial_state)
    return replay_buffer, metrics
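An illustrative call to `run_env`; `env` and `policy` are assumed to exist already (e.g. a RandomPyPolicy on a suite_gym environment, as in example #18), and are not part of the original snippet.

# Hypothetical usage; `env` and `policy` are placeholders.
replay_buffer, metrics = run_env(env, policy, max_episodes=5)
for metric in metrics:
    logging.info('%s: %f', metric.name, metric.result())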
Code Example #14
  def test_observer_resets(self, create_observer_fn,
                           reset_with_write_cached_steps, append_count,
                           expected_items, append_count_from_reset,
                           expected_items_from_reset):
    env = _env_creator(5)()
    with create_observer_fn(self._client) as observer:
      policy = _create_random_policy_from_env(env)
      driver = py_driver.PyDriver(
          env, policy, observers=[observer], max_steps=11)
      driver.run(env.reset())

      self.assertEqual(append_count, self._writer.append.call_count)
      self.assertEqual(expected_items, self._writer.create_item.call_count)
      observer.reset(write_cached_steps=reset_with_write_cached_steps)
      self.assertEqual(append_count + append_count_from_reset,
                       self._writer.append.call_count)
      self.assertEqual(expected_items + expected_items_from_reset,
                       self._writer.create_item.call_count)
Code Example #15
    def test_trajectory_observer_no_mock(self):
        create_observer_fn = _create_add_trajectory_observer_fn(
            table_name=self._table_name, sequence_length=2)
        env = _env_creator(episode_len=6)()

        self._reverb_client.reset(self._table_name)
        with create_observer_fn(self._reverb_client) as observer:
            policy = _create_random_policy_from_env(env)
            driver = py_driver.PyDriver(env,
                                        policy,
                                        observers=[observer],
                                        max_steps=5)
            driver.run(env.reset())
            # Give it some time for the items to reach Reverb.
            time.sleep(1)

            self.assertEqual(observer._cached_steps, 5)
            self.assertEqual(self._table.info.current_size, 4)
Code Example #16
    def test_episodic_observer_no_mock(self):
        create_observer_fn = _create_add_episode_observer_fn(
            table_name=self._table_name, max_sequence_length=8, priority=3)
        env = _env_creator(episode_len=3)()

        self._reverb_client.reset(self._table_name)
        with create_observer_fn(self._reverb_client) as observer:
            policy = _create_random_policy_from_env(env)
            driver = py_driver.PyDriver(env,
                                        policy,
                                        observers=[observer],
                                        max_steps=10)
            driver.run(env.reset())
            # Give it some time for the items to reach Reverb.
            time.sleep(1)

            # We run the driver for 3 full episodes and one step.
            self.assertEqual(observer._cached_steps, 1)
            self.assertEqual(self._table.info.current_size, 3)
Code Example #17
 def _insert_random_data(self,
                         env,
                         num_steps,
                         sequence_length=2,
                         additional_observers=None):
   """Insert `num_step` random observations into Reverb server."""
   observers = [] if additional_observers is None else additional_observers
   traj_obs = reverb_utils.ReverbAddTrajectoryObserver(
       self._py_client, self._table_name, sequence_length=sequence_length)
   observers.append(traj_obs)
   policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                            env.action_spec())
   driver = py_driver.PyDriver(env,
                               policy,
                               observers=observers,
                               max_steps=num_steps)
   time_step = env.reset()
   driver.run(time_step)
   traj_obs.close()
Code Example #18
def profile_env(env_str, max_ep_len, n_steps=None, env_wrappers=()):
  n_steps = n_steps or max_ep_len * 2
  profile = [None]

  def profile_fn(p):
    assert isinstance(p, cProfile.Profile)
    profile[0] = p

  py_env = suite_gym.load(env_str, gym_env_wrappers=env_wrappers,
                          max_episode_steps=max_ep_len)
  env = wrappers.PerformanceProfiler(
    py_env, process_profile_fn=profile_fn,
    process_steps=n_steps)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(), env.action_spec())

  driver = py_driver.PyDriver(env, policy, [], max_steps=n_steps)
  time_step = env.reset()
  policy_state = policy.get_initial_state()
  for _ in range(n_steps):
    time_step, policy_state = driver.run(time_step, policy_state)
  stats = pstats.Stats(profile[0])
  stats.print_stats()
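Calling the profiler is then a single line; the environment name below is an assumption for illustration only:

# Hypothetical invocation; prints cProfile statistics after roughly 400 steps.
profile_env('CartPole-v0', max_ep_len=200)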
Code Example #19
File: py_driver_test.py  Project: tensorflow/agents
  def testPolicyStateReset(self):
    num_episodes = 2
    num_expected_steps = 6

    env = driver_test_utils.PyEnvironmentMock()
    policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    replay_buffer_observer = MockReplayBufferObserver()
    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=None,
        max_episodes=num_episodes,
    )

    time_step = env.reset()
    policy_state = policy.get_initial_state()
    time_step, policy_state = driver.run(time_step, policy_state)
    trajectories = replay_buffer_observer.gather_all()
    self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
    self.assertEqual(num_episodes, policy.get_initial_state_call_count)
Code Example #20
File: py_driver_test.py  Project: tensorflow/agents
  def testMultipleRunMaxSteps(self):
    num_steps = 3
    num_expected_steps = 4

    env = driver_test_utils.PyEnvironmentMock()
    policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    replay_buffer_observer = MockReplayBufferObserver()
    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=1,
        max_episodes=None,
    )

    time_step = env.reset()
    policy_state = policy.get_initial_state()
    for _ in range(num_steps):
      time_step, policy_state = driver.run(time_step, policy_state)
    trajectories = replay_buffer_observer.gather_all()
    self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
Code Example #21
  def __init__(self,
               env,
               policy,
               train_step,
               steps_per_run=None,
               episodes_per_run=None,
               observers=None,
               metrics=None,
               reference_metrics=None,
               summary_dir=None,
               summary_interval=1000,
               name=""):
    """Initializes an Actor.

    Args:
      env: An instance of either a tf or py environment. Note that the policy
        and observers should match the tf/py-ness of the env.
      policy: An instance of a policy used to interact with the environment.
      train_step: A scalar tf.int64 `tf.Variable` which will keep track of the
        number of train steps. This is used for generated artifacts such as
        summaries.
      steps_per_run: Number of steps evaluated per run call. See below.
      episodes_per_run: Number of episodes evaluated per run call.
      observers: A list of observers that are notified after every step in the
        environment. Each observer is a callable(trajectory.Trajectory).
      metrics: A list of metric observers.
      reference_metrics: Optional list of metrics against which other metrics
        are plotted. For example, passing in a metric that tracks the number of
        environment episodes will result in summaries of all other metrics over
        this value. Note that summaries against the train_step are written by
        default.
      summary_dir: Path used for summaries. If no path is provided no summaries
        are written.
      summary_interval: How often summaries are written.
      name: Name for the actor used as a prefix to generated summaries.
    """
    self._env = env
    self._policy = policy
    self._train_step = train_step
    self._observers = observers or []
    self._metrics = metrics or []
    self._observers.extend(self._metrics)
    self._reference_metrics = reference_metrics or []
    self._observers.extend(self._reference_metrics)
    self._observers = list(set(self._observers))

    self._write_summaries = bool(summary_dir)  # summary_dir is not None

    if self._write_summaries:
      self._summary_writer = tf.summary.create_file_writer(
          summary_dir, flush_millis=10000)
    else:
      self._summary_writer = NullSummaryWriter()

    self._summary_interval = summary_interval
    # In order to write summaries at `train_step=0` as well.
    self._last_summary = -summary_interval

    self._name = name

    if isinstance(env, py_environment.PyEnvironment):
      self._driver = py_driver.PyDriver(
          env,
          policy,
          self._observers,
          max_steps=steps_per_run,
          max_episodes=episodes_per_run)
    elif isinstance(env, tf_environment.TFEnvironment):
      raise ValueError("Actor doesn't support TFEnvironments yet.")
    else:
      raise ValueError("Unknown environment type.")

    self.reset()
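A minimal construction sketch for this Actor follows. It assumes a py environment, an eager-compatible py policy, and a replay-buffer observer are already defined (as in the earlier examples), and that the class exposes the usual `run()` entry point of TF-Agents' Actor; none of that is shown in the snippet above.

import tensorflow as tf

# Hypothetical setup; `env`, `collect_policy`, and `rb_observer` are assumed to
# exist (a PyEnvironment, a PyTFEagerPolicy, and a replay-buffer observer).
train_step = tf.Variable(0, dtype=tf.int64)
collect_actor = Actor(
    env,
    collect_policy,
    train_step,
    steps_per_run=1,
    observers=[rb_observer],
    name='collect_actor')
collect_actor.run()  # Assumes the standard Actor.run() method.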