def test_profiling(self):
  """Checks that PerformanceProfiler profiles resets and steps.

  The wrapper is built with `process_steps=2`. The test expects
  `profile_fn` to have received a `cProfile.Profile` after two `step`
  calls, and a different `Profile` object after two more.
  """
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)

  # Single-element list so the nested callback can record the latest profile.
  profile = [None]

  def profile_fn(p):
    self.assertIsInstance(p, cProfile.Profile)
    profile[0] = p

  env = wrappers.PerformanceProfiler(
      env, process_profile_fn=profile_fn, process_steps=2)

  env.reset()
  # Resets are also profiled.
  s = pstats.Stats(env._profile)
  self.assertGreater(s.total_calls, 0)  # pytype: disable=attribute-error

  for _ in range(2):
    env.step(np.array(1, dtype=np.int32))
  self.assertIsNotNone(profile[0])
  previous_profile = profile[0]

  # The reported profile should hold more calls than the reset-only snapshot.
  updated_s = pstats.Stats(profile[0])
  self.assertGreater(updated_s.total_calls, s.total_calls)  # pytype: disable=attribute-error

  for _ in range(2):
    env.step(np.array(1, dtype=np.int32))
  self.assertIsNotNone(profile[0])
  # We saw a new profile: the callback must have been handed a different
  # object, so compare by identity rather than equality.
  self.assertIsNot(profile[0], previous_profile)
def profile_env(env_str, max_ep_len, n_steps=None, env_wrappers=None):
  """Profiles random-policy rollouts of a gym environment and prints stats.

  Args:
    env_str: Environment id passed to `suite_gym.load`.
    max_ep_len: Passed to `suite_gym.load` as `max_episode_steps`.
    n_steps: Profiling window size. It is the `PerformanceProfiler`'s
      `process_steps` and the driver's `max_steps`. Defaults to
      `2 * max_ep_len` when falsy.
    env_wrappers: Optional sequence of gym wrappers passed to
      `suite_gym.load` as `gym_env_wrappers`. `None` means no wrappers.

  Returns:
    The `pstats.Stats` built from the profile reported by the
    `PerformanceProfiler`. The stats are also printed to stdout.

  Raises:
    RuntimeError: If the profiler never reported a profile.
  """
  n_steps = n_steps or max_ep_len * 2
  # Fresh empty tuple per call instead of a shared mutable default list.
  env_wrappers = () if env_wrappers is None else env_wrappers

  # Single-element list so the nested callback can record the latest profile.
  profile = [None]

  def profile_fn(p):
    assert isinstance(p, cProfile.Profile)
    profile[0] = p

  py_env = suite_gym.load(
      env_str, gym_env_wrappers=env_wrappers, max_episode_steps=max_ep_len)
  env = wrappers.PerformanceProfiler(
      py_env, process_profile_fn=profile_fn, process_steps=n_steps)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                           env.action_spec())
  driver = py_driver.PyDriver(env, policy, [], max_steps=n_steps)

  time_step = env.reset()
  policy_state = policy.get_initial_state()
  # Each `run` call already executes up to `n_steps` steps. Stop as soon as
  # the profiler has reported a window, rather than always making `n_steps`
  # calls (about n_steps**2 environment steps). The original call count
  # remains the upper bound.
  for _ in range(n_steps):
    if profile[0] is not None:
      break
    time_step, policy_state = driver.run(time_step, policy_state)

  if profile[0] is None:
    # pstats.Stats(None) would silently print empty stats.
    raise RuntimeError(
        'PerformanceProfiler did not report a profile for %r.' % (env_str,))

  stats = pstats.Stats(profile[0])
  stats.print_stats()
  return stats