from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from tf_agents.environments import py_environment
from tf_agents.policies import random_py_policy
from tf_agents.specs import array_spec
from tf_agents.typing import types

# `agent.DQNAgent` / `DQNAgent` and `render_time_step` are project-specific
# and assumed to be importable from the surrounding code base.


def eval_agent(env: py_environment.PyEnvironment,
               tf_agent: agent.DQNAgent,
               n_episodes: int,
               reward_vector: bool = False) -> np.ndarray:
    """Runs the agent's greedy policy for `n_episodes` and collects returns."""
    results = []
    for _ in tqdm(range(n_episodes)):
        ts = env.reset()
        observations = ts.observation
        episode_reward = 0
        done = False
        while not done:
            action = tf_agent.greedy_policy(observations)
            ts = env.step(action)
            observations, reward, done = (ts.observation, ts.reward,
                                          ts.is_last())
            episode_reward += reward
        # Sanity check: the accumulated return must match the utility the
        # environment tracked internally over the episode.
        assert np.isclose(episode_reward, env._prev_step_utility, atol=1e-05)
        if reward_vector:
            # Keep the utility representation together with a copy of the
            # per-objective cumulative rewards.
            results.append([
                observations['utility_representation'],
                np.copy(env._cumulative_rewards)
            ])
        else:
            results.append(episode_reward)
    if reward_vector:
        # The entries are ragged pairs, so use an object array.
        results = np.array(results, dtype='object')
    else:
        results = np.array(results)
    return results
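# Usage sketch for eval_agent. The agent and environment here are
# project-specific (greedy_policy, _prev_step_utility and _cumulative_rewards
# are not part of TF-Agents), so `make_env` and `trained_agent` below are
# hypothetical placeholders rather than real constructors.
#
#   env = make_env()
#   returns = eval_agent(env, trained_agent, n_episodes=100)
#   print('mean return: %.3f +/- %.3f' % (returns.mean(), returns.std()))
#
# With reward_vector=True, each result is a [utility_representation,
# cumulative_rewards] pair stored in an object array instead of a scalar
# episode return.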
def validate_py_environment(
        environment: py_environment.PyEnvironment,
        episodes: int = 5,
        observation_and_action_constraint_splitter: Optional[
            types.Splitter] = None):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    if environment.batch_size is not None:
        batched_time_step_spec = array_spec.add_outer_dims_nest(
            time_step_spec, outer_dims=(environment.batch_size,))
    else:
        batched_time_step_spec = time_step_spec

    episode_count = 0
    time_step = environment.reset()
    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
            raise ValueError('Given `time_step`: %r does not match expected '
                             '`time_step_spec`: %r' %
                             (time_step, batched_time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        # In a batched environment several episodes can finish on one step.
        episode_count += np.sum(time_step.is_last())
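# Example: validating a stock Gym environment. suite_gym.load is the standard
# TF-Agents loader, so this snippet should run as-is with tf_agents and gym
# installed.
from tf_agents.environments import suite_gym

cartpole_env = suite_gym.load('CartPole-v0')
validate_py_environment(cartpole_env, episodes=5)  # raises ValueError on a spec mismatch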
# Unbatched variant: counts one episode at a time and resets manually at
# episode boundaries instead of summing batched `is_last` flags.
def validate_py_environment(environment: py_environment.PyEnvironment,
                            episodes: int = 5):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()
    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    episode_count = 0
    time_step = environment.reset()
    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, time_step_spec):
            raise ValueError(
                'Given `time_step`: %r does not match expected '
                '`time_step_spec`: %r' % (time_step, time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        if time_step.is_last():
            episode_count += 1
            time_step = environment.reset()
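# Minimal custom environment to exercise the validator -- a sketch, not part
# of the original code base. It emits a single float observation, accepts a
# binary action, and terminates every episode after ten steps. Relies on the
# np / py_environment / array_spec imports above.
from tf_agents.trajectories import time_step as ts_lib


class CountingEnv(py_environment.PyEnvironment):

    def __init__(self):
        super().__init__()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
        self._observation_spec = array_spec.ArraySpec(
            shape=(1,), dtype=np.float32, name='observation')
        self._count = 0

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._count = 0
        return ts_lib.restart(np.array([0.0], dtype=np.float32))

    def _step(self, action):
        self._count += 1
        obs = np.array([self._count], dtype=np.float32)
        if self._count >= 10:
            return ts_lib.termination(obs, reward=1.0)
        return ts_lib.transition(obs, reward=0.0)


validate_py_environment(CountingEnv(), episodes=3)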
def agent_play_episode(env: py_environment.PyEnvironment,
                       agent: DQNAgent) -> None:
    """Plays one greedy episode and renders every time step on a 4x8 grid."""
    time_step = env.reset()
    plt.figure(figsize=(11, 6), dpi=200)
    i = 1
    ax = plt.subplot(4, 8, i)
    render_time_step(time_step, ax)
    while not time_step.is_last():
        action = agent.greedy_policy(time_step.observation)
        time_step = env.step(action)
        i += 1
        # The 4x8 grid holds at most 32 panels, so episodes are assumed to
        # last no more than 32 steps; plt.subplot raises beyond that.
        ax = plt.subplot(4, 8, i)
        render_time_step(time_step, ax, action)
    plt.tight_layout()
    plt.show()
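# Usage sketch (hypothetical names): render one greedy episode of a trained
# agent. Because of the fixed 4x8 grid, this helper only suits environments
# whose episodes last at most 32 steps.
#
#   env = make_env()                        # placeholder constructor
#   agent_play_episode(env, trained_agent)  # pops up a matplotlib figure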