Beispiel #1
0
  def test_profiling(self):
    """Checks that PerformanceProfiler profiles resets and emits new profiles.

    The wrapper is configured with `process_steps=2`, and the test verifies
    that the callback has received a profile after each pair of steps and that
    the second profile differs from the first.
    """
    latest = {'profile': None}

    def record_profile(prof):
      # The callback must be handed a genuine cProfile.Profile instance.
      self.assertIsInstance(prof, cProfile.Profile)
      latest['profile'] = prof

    def step_twice(wrapped_env):
      # Two steps matches the `process_steps=2` setting used below.
      for _ in range(2):
        wrapped_env.step(np.array(1, dtype=np.int32))

    env = wrappers.PerformanceProfiler(
        gym_wrapper.GymWrapper(gym.make('CartPole-v1')),
        process_profile_fn=record_profile,
        process_steps=2)

    env.reset()

    # The reset alone should already have produced profiled calls.
    reset_stats = pstats.Stats(env._profile)
    self.assertGreater(reset_stats.total_calls, 0)  # pytype: disable=attribute-error

    step_twice(env)

    self.assertIsNotNone(latest['profile'])
    first_profile = latest['profile']

    # The profile handed to the callback covers more calls than the reset did.
    stepped_stats = pstats.Stats(latest['profile'])
    self.assertGreater(stepped_stats.total_calls, reset_stats.total_calls)  # pytype: disable=attribute-error

    step_twice(env)

    self.assertIsNotNone(latest['profile'])
    # A second batch of steps must deliver a different profile object.
    self.assertNotEqual(latest['profile'], first_profile)
Beispiel #2
0
def profile_env(env_str, max_ep_len, n_steps=None, env_wrappers=()):
  """Profiles a Gym environment driven by a random policy and prints the stats.

  The environment is loaded through `suite_gym.load`, wrapped in
  `wrappers.PerformanceProfiler`, and stepped with a `RandomPyPolicy` via a
  `PyDriver`. The most recent profile handed to the profiler callback is
  printed with `pstats`.

  Args:
    env_str: Environment name passed to `suite_gym.load`.
    max_ep_len: Value passed to `suite_gym.load` as `max_episode_steps`; also
      used to derive the default for `n_steps`.
    n_steps: Used as the profiler's `process_steps`, the driver's `max_steps`,
      and the number of `driver.run` calls. A falsy value (None or 0) falls
      back to `max_ep_len * 2`.
    env_wrappers: Sequence forwarded to `suite_gym.load` as
      `gym_env_wrappers`. Defaults to an empty tuple rather than a list so the
      default cannot be mutated and shared across calls.
  """
  n_steps = n_steps or max_ep_len * 2
  # One-element list so the nested callback can hand the profile back out.
  profile = [None]

  def profile_fn(p):
    assert isinstance(p, cProfile.Profile)
    profile[0] = p

  py_env = suite_gym.load(env_str, gym_env_wrappers=env_wrappers,
                          max_episode_steps=max_ep_len)
  env = wrappers.PerformanceProfiler(
    py_env, process_profile_fn=profile_fn,
    process_steps=n_steps)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(), env.action_spec())

  driver = py_driver.PyDriver(env, policy, [], max_steps=n_steps)
  time_step = env.reset()
  policy_state = policy.get_initial_state()
  # NOTE(review): the driver is built with max_steps=n_steps and run() is
  # called n_steps times, so presumably up to n_steps**2 environment steps are
  # taken in total -- confirm this is intended.
  for _ in range(n_steps):
    time_step, policy_state = driver.run(time_step, policy_state)
  stats = pstats.Stats(profile[0])
  stats.print_stats()