Example #1
import numpy as np
from tqdm import tqdm

from tf_agents.environments import py_environment

import agent  # project-local module assumed to define DQNAgent with a greedy_policy() method


def eval_agent(env: py_environment.PyEnvironment,
               tf_agent: agent.DQNAgent,
               n_episodes: int,
               reward_vector: bool = False) -> np.ndarray:
    """Runs the agent greedily for n_episodes and returns the collected results."""

    results = []
    for _ in tqdm(range(n_episodes)):
        ts = env.reset()
        observations = ts.observation

        episode_reward = 0
        done = False
        while not done:
            action = tf_agent.greedy_policy(observations)
            ts = env.step(action)
            observations = ts.observation
            reward = ts.reward
            done = ts.is_last()

            episode_reward += reward

        # Sanity check: the accumulated reward should match the environment's
        # own utility bookkeeping (a private attribute of this project's env).
        assert np.isclose(episode_reward, env._prev_step_utility, atol=1e-05)
        if reward_vector:
            # Pair the utility representation observed at episode end with a
            # copy of the environment's cumulative per-objective rewards.
            results.append([
                observations['utility_representation'],
                np.copy(env._cumulative_rewards)
            ])
        else:
            results.append(episode_reward)

    if reward_vector:
        results = np.array(results, dtype='object')
    else:
        results = np.array(results)

    return results
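
A hypothetical call site, for reference only; make_env and load_trained_agent below are placeholder names standing in for whatever environment/agent constructors the surrounding project actually provides:

env = make_env()                    # placeholder: any py_environment.PyEnvironment
dqn = load_trained_agent(env)       # placeholder: any agent exposing greedy_policy()

returns = eval_agent(env, dqn, n_episodes=100)
print('mean return: %.3f +/- %.3f' % (returns.mean(), returns.std()))

# With reward_vector=True each entry pairs the final utility representation
# with a copy of the cumulative per-objective rewards.
vector_results = eval_agent(env, dqn, n_episodes=100, reward_vector=True)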
Example #2
from typing import Optional

import numpy as np

from tf_agents.environments import py_environment
from tf_agents.policies import random_py_policy
from tf_agents.specs import array_spec
from tf_agents.typing import types


def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    # Batched environments emit arrays with a leading batch dimension, so the
    # spec needs matching outer dims before checking.
    if environment.batch_size is not None:
        batched_time_step_spec = array_spec.add_outer_dims_nest(
            time_step_spec, outer_dims=(environment.batch_size, ))
    else:
        batched_time_step_spec = time_step_spec

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
            raise ValueError('Given `time_step`: %r does not match expected '
                             '`time_step_spec`: %r' %
                             (time_step, batched_time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        # In a batched environment several sub-episodes can end on the same
        # step, so count every terminal flag.
        episode_count += np.sum(time_step.is_last())
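
For reference, a minimal way to exercise this validator, assuming TF-Agents and its Gym suite are installed (suite_gym.load is the standard TF-Agents loader; 'CartPole-v0' is only an example environment):

from tf_agents.environments import suite_gym

env = suite_gym.load('CartPole-v0')
validate_py_environment(env, episodes=5)  # raises ValueError on any spec mismatch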
Example #3
# A simpler variant of the validator above: no constraint splitter and no
# batch handling, with an explicit reset at every episode boundary. It relies
# on the same TF-Agents imports as Example #2.
def validate_py_environment(environment: py_environment.PyEnvironment,
                            episodes: int = 5):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, time_step_spec):
            raise ValueError(
                'Given `time_step`: %r does not match expected `time_step_spec`: %r'
                % (time_step, time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        if time_step.is_last():
            episode_count += 1
            time_step = environment.reset()
Example #4
import matplotlib.pyplot as plt

from tf_agents.environments import py_environment

# DQNAgent and render_time_step are assumed to be project-local: the custom
# agent from Example #1 and a helper that draws a single time step on an axis.


def agent_play_episode(env: py_environment.PyEnvironment, agent: DQNAgent) -> None:
    """Plays one greedy episode and renders every time step in a subplot grid."""
    time_step = env.reset()

    # A 4 x 8 grid, i.e. room for at most 32 rendered time steps per episode.
    plt.figure(figsize=(11, 6), dpi=200)
    i = 1
    ax = plt.subplot(4, 8, i)
    render_time_step(time_step, ax)
    while not time_step.is_last():
        action = agent.greedy_policy(time_step.observation)
        time_step = env.step(action)
        i += 1
        ax = plt.subplot(4, 8, i)
        render_time_step(time_step, ax, action)

    plt.tight_layout()
    plt.show()
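
A hypothetical invocation for reference (make_env and load_trained_agent are the same placeholder constructors used in the sketch after Example #1):

env = make_env()                 # placeholder PyEnvironment constructor
dqn = load_trained_agent(env)    # placeholder agent with a greedy_policy() method
agent_play_episode(env, dqn)     # shows one panel per rendered time step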