Example #1
def eval_agent(env: py_environment.PyEnvironment,
               tf_agent: agent.DQNAgent,
               n_episodes: int,
               reward_vector: bool = False) -> np.ndarray:
    results = []
    for _ in tqdm(range(n_episodes)):
        ts = env.reset()
        observations = ts.observation

        episode_reward = 0
        done = False
        while not done:
            action = tf_agent.greedy_policy(observations)
            ts = env.step(action)
            observations, reward, done = (ts.observation, ts.reward,
                                          ts.is_last())

            episode_reward += reward

        assert np.isclose(episode_reward, env._prev_step_utility, atol=1e-05)
        if reward_vector:
            results.append([
                observations['utility_representation'],
                np.copy(env._cumulative_rewards)
            ])
        else:
            results.append(episode_reward)

    if reward_vector:
        results = np.array(results, dtype='object')
    else:
        results = np.array(results)

    return results
Example #2
File: wrappers.py Project: wau/agents
    def __init__(
        self,
        env: py_environment.PyEnvironment,
        history_length: int = 3,
        include_actions: bool = False,
        tile_first_step_obs: bool = False,
    ):
        """Initializes a HistoryWrapper.

    Args:
      env: Environment to wrap.
      history_length: Length of the history to attach.
      include_actions: Whether actions should be included in the history.
      tile_first_step_obs: If True the observation on reset is tiled to fill the
       history.
    """
        super(HistoryWrapper, self).__init__(env)
        self._history_length = history_length
        self._include_actions = include_actions
        self._tile_first_step_obs = tile_first_step_obs

        self._zero_observation = self._zeros_from_spec(env.observation_spec())
        self._zero_action = self._zeros_from_spec(env.action_spec())

        self._observation_history = collections.deque(maxlen=history_length)
        self._action_history = collections.deque(maxlen=history_length)

        self._observation_spec = self._get_observation_spec()
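
A minimal usage sketch for the wrapper above, assuming it matches the HistoryWrapper shipped in tf_agents.environments.wrappers and that suite_gym is available; the environment name is illustrative, not from the source:

# Sketch only: wrap a simple environment so each observation carries the last
# 3 observations (and, with include_actions=True, the last 3 actions as well).
from tf_agents.environments import suite_gym, wrappers

env = suite_gym.load('CartPole-v1')
history_env = wrappers.HistoryWrapper(
    env, history_length=3, include_actions=True)
print(history_env.observation_spec())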
Example #3
File: wrappers.py Project: wau/agents
    def __init__(self, env: py_environment.PyEnvironment,
                 idx: Union[Sequence[int], np.ndarray]):
        """Creates an observation filter wrapper.

    Args:
      env: Environment to wrap.
      idx: Array of indexes pointing to elements to include in output.

    Raises:
      ValueError: If observation spec is nested.
      ValueError: If indexes are not single-dimensional.
      ValueError: If no index is provided.
      ValueError: If one of the indexes is out of bounds.
    """
        super(ObservationFilterWrapper, self).__init__(env)
        idx = np.array(idx)
        if tf.nest.is_nested(env.observation_spec()):
            raise ValueError(
                'ObservationFilterWrapper only works with single-array '
                'observations (not nested).')
        if len(idx.shape) != 1:
            raise ValueError('ObservationFilterWrapper only works with '
                             'single-dimensional indexes for filtering.')
        if idx.shape[0] < 1:
            raise ValueError(
                'At least one index needs to be provided for filtering.')
        if not np.all(idx < env.observation_spec().shape[0]):
            raise ValueError('One of the indexes is out of bounds.')

        self._idx = idx
        self._observation_spec = env.observation_spec().replace(
            shape=idx.shape)
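
A hedged usage sketch, assuming the wrapper above matches tf_agents.environments.wrappers.ObservationFilterWrapper; the environment and the chosen indexes are illustrative only:

# Sketch only: keep elements 0 (cart position) and 2 (pole angle) of
# CartPole's 4-element observation vector.
from tf_agents.environments import suite_gym, wrappers

env = suite_gym.load('CartPole-v1')            # observation spec shape (4,)
filtered_env = wrappers.ObservationFilterWrapper(env, idx=[0, 2])
print(filtered_env.observation_spec().shape)   # expected: (2,)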
Example #4
File: utils.py Project: tensorflow/agents
def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    if environment.batch_size is not None:
        batched_time_step_spec = array_spec.add_outer_dims_nest(
            time_step_spec, outer_dims=(environment.batch_size, ))
    else:
        batched_time_step_spec = time_step_spec

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
            raise ValueError('Given `time_step`: %r does not match expected '
                             '`time_step_spec`: %r' %
                             (time_step, batched_time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        episode_count += np.sum(time_step.is_last())
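
For reference, a typical call to this validator looks like the sketch below; the environment name is illustrative and the optional splitter argument is omitted:

# Sketch only: drives the environment with a random policy for a few episodes
# and raises ValueError if any time step violates the declared specs.
from tf_agents.environments import suite_gym, utils

env = suite_gym.load('CartPole-v1')
utils.validate_py_environment(env, episodes=5)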
Example #5
File: wrappers.py Project: wau/agents
    def __init__(self, env: py_environment.PyEnvironment,
                 num_actions: np.ndarray):
        """Constructs a wrapper for discretizing the action space.

    **Note:** Only environments with a single BoundedArraySpec are supported.

    Args:
      env: Environment to wrap.
      num_actions: A np.array of the same shape as the environment's
        action_spec. Elements in the array specify the number of actions to
        discretize to for each dimension.

    Raises:
      ValueError: If the action_spec shape and the limits shape are not equal.
    """
        super(ActionDiscretizeWrapper, self).__init__(env)

        action_spec = tf.nest.flatten(env.action_spec())
        if len(action_spec) != 1:
            raise ValueError(
                'ActionDiscretizeWrapper only supports environments with a single '
                'action spec. Got {}'.format(env.action_spec()))

        action_spec = action_spec[0]
        self._original_spec = action_spec
        self._num_actions = np.broadcast_to(num_actions, action_spec.shape)

        if action_spec.shape != self._num_actions.shape:
            raise ValueError(
                'Spec {} and limit shape do not match. Got {}'.format(
                    action_spec, self._num_actions.shape))

        self._discrete_spec, self._action_map = self._discretize_spec(
            action_spec, self._num_actions)
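
A short usage sketch, assuming the wrapper above is the one exported as tf_agents.environments.wrappers.ActionDiscretizeWrapper; Pendulum is used only because it has a single continuous action dimension:

# Sketch only: map Pendulum's continuous torque in [-2, 2] to 5 discrete actions.
from tf_agents.environments import suite_gym, wrappers

env = suite_gym.load('Pendulum-v1')
discrete_env = wrappers.ActionDiscretizeWrapper(env, num_actions=5)
print(env.action_spec())            # continuous BoundedArraySpec, shape (1,)
print(discrete_env.action_spec())   # discrete spec with 5 actions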
Example #6
def collection_step(env: py_environment.PyEnvironment,
                    tf_agent: agent.DQNAgent,
                    replay_memory: agent.ReplayMemory,
                    reward_tracker: agent.RewardTracker,
                    collect_episodes: int) -> float:
    """Collects episodes with the epsilon-greedy policy, stores transitions
    in the replay memory, and returns the last episode's reward."""
    if not collect_episodes:
        return 0.0

    for _ in range(collect_episodes):
        # Reset env
        ts = env.reset()
        observations = ts.observation

        episode_reward = 0
        done = False
        while not done:
            action = tf_agent.epsilon_greedy_policy(observations,
                                                    training=True)
            ts = env.step(action)
            next_obs, reward, done = ts.observation, ts.reward, ts.is_last()

            replay_memory.append(
                (observations, action, reward, next_obs, done))
            observations = next_obs

            episode_reward += reward

        reward_tracker.append(episode_reward)

    return episode_reward
Example #7
def create_video(py_env: py_environment.PyEnvironment,
                 tf_env: tf_environment.TFEnvironment,
                 policy: tf_py_policy.TFPyPolicy,
                 num_episodes=10,
                 max_episode_length=60 * 30,
                 video_filename='eval_video.mp4'):
    logging.info('Generating video %s', video_filename)
    py_env.reset()
    with imageio.get_writer(video_filename, fps=60) as vid:
        for episode in range(num_episodes):
            logging.info('\tEpisode %s of %s', episode + 1, num_episodes)

            frames = 0
            time_step = tf_env.reset()
            py_env.reset()
            state = policy.get_initial_state(tf_env.batch_size)

            vid.append_data(py_env.render(mode='rgb_array'))
            while not time_step.is_last() and frames < max_episode_length:
                if frames % 60 == 0:
                    logging.info('Frame %s of %s', frames, max_episode_length)
                policy_step = policy.action(time_step, state)
                state = policy_step.state
                time_step = tf_env.step(policy_step.action)
                py_env.step(policy_step.action)
                vid.append_data(py_env.render(mode='rgb_array'))
                frames += 1
    # Close the rendering environment once all episodes have been captured.
    py_env.close()
    logging.info('Finished rendering video %s', video_filename)
Example #8
File: wrappers.py Project: morgandu/agents
    def __init__(self, env: py_environment.PyEnvironment, flat_dtype=None):
        """Creates a FlattenActionWrapper.

    Args:
      env: Environment to wrap.
      flat_dtype: Optional, if set to a np.dtype the flat action_spec uses this
        dtype.

    Raises:
      ValueError: If any of the action_spec shapes ndim > 1.
      ValueError: If dtypes differ across action specs and flat_dtype is not
        set.
    """
        super(FlattenActionWrapper, self).__init__(env)
        self._original_action_spec = env.action_spec()
        flat_action_spec = tf.nest.flatten(env.action_spec())

        if any([len(s.shape) > 1 for s in flat_action_spec]):
            raise ValueError('ActionSpec shapes should all have ndim == 1.')

        if flat_dtype is None and any(
            [s.dtype != flat_action_spec[0].dtype for s in flat_action_spec]):
            raise ValueError(
                'All action_spec dtypes must match, or `flat_dtype` should be set.'
            )

        # Sum the flat lengths; `(s.shape and s.shape[0]) or 1` handles scalar
        # shapes (). The trailing comma makes `shape` a one-element tuple.
        shape = sum([(s.shape and s.shape[0]) or 1 for s in flat_action_spec]),

        if all([
                isinstance(s, array_spec.BoundedArraySpec)
                for s in flat_action_spec
        ]):
            minimums = [
                np.broadcast_to(s.minimum, shape=s.shape)
                for s in flat_action_spec
            ]
            maximums = [
                np.broadcast_to(s.maximum, shape=s.shape)
                for s in flat_action_spec
            ]

            minimum = np.hstack(minimums)
            maximum = np.hstack(maximums)
            self._action_spec = array_spec.BoundedArraySpec(
                shape=shape,
                dtype=flat_dtype or flat_action_spec[0].dtype,
                minimum=minimum,
                maximum=maximum,
                name='FlattenedActionSpec')
        else:
            self._action_spec = array_spec.ArraySpec(
                shape=shape,
                dtype=flat_dtype or flat_action_spec[0].dtype,
                name='FlattenedActionSpec')

        self._flat_action_spec = flat_action_spec
Example #9
def agent_play_episode(env: py_environment.PyEnvironment, agent: DQNAgent) -> None:
    time_step = env.reset()

    plt.figure(figsize=(11, 6), dpi=200)
    i = 1
    ax = plt.subplot(4, 8, i)
    render_time_step(time_step, ax)
    while not time_step.is_last():
        action = agent.greedy_policy(time_step.observation)
        time_step = env.step(action)
        i += 1
        ax = plt.subplot(4, 8, i)
        render_time_step(time_step, ax, action)

    plt.tight_layout()
    plt.show()
Example #10
    def __init__(
        self, evaluate: Callable[[np.ndarray], float] = default_eval_method
    ):
        PyEnvironment.__init__(self)

        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=_N_ACTIONS - 1, name="action"
        )
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(GRID_SIZE, GRID_SIZE, 1),
            dtype=np.float32,
            minimum=0,
            name="observation",
        )
        self._evaluate = evaluate
        self.render_size = 320
Example #11
def create_video(py_environment: PyEnvironment,
                 tf_environment: TFPyEnvironment,
                 policy: tf_policy,
                 num_episodes=10,
                 video_filename='imageio.mp4'):
    print("Generating video %s" % video_filename)
    with imageio.get_writer(video_filename, fps=60) as video:
        for episode in range(num_episodes):
            print("Generating episode %d of %d" % (episode, num_episodes))

            time_step = tf_environment.reset()
            video.append_data(py_environment.render())
            while not time_step.is_last():
                action_step = policy.action(time_step)

                time_step = tf_environment.step(action_step.action)
                video.append_data(py_environment.render())
Example #12
def create_video(py_environment: PyEnvironment,
                 tf_environment: TFEnvironment,
                 policy: tf_policy.Base,
                 num_episodes=10,
                 video_filename='imageio.mp4'):
	logging.info("Generating video %s" % video_filename)
	with imageio.get_writer(video_filename, fps=60) as video:
		for episode in range(num_episodes):
			logging.info("Generating episode %d of %d" % (episode, num_episodes))

			time_step = tf_environment.reset()
			state = policy.get_initial_state(tf_environment.batch_size)

			video.append_data(py_environment.render())
			while not time_step.is_last():
				policy_step: PolicyStep = policy.action(time_step, state)
				state = policy_step.state
				time_step = tf_environment.step(policy_step.action)
				video.append_data(py_environment.render())

	logging.info("Finished video %s" % video_filename)
Example #13
def create_video(py_environment: PyEnvironment,
                 tf_environment: TFPyEnvironment,
                 policy: tf_policy,
                 num_episodes=10,
                 video_filename='imageio.mp4'):
    print("Generating video %s" % video_filename)
    with imageio.get_writer(video_filename, fps=60) as video:
        for episode in range(num_episodes):
            episode_return = 0.0
            time_step = tf_environment.reset()
            video.append_data(py_environment.render())
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = tf_environment.step(action_step.action)
                episode_return += time_step.reward
                video.append_data(py_environment.render())
            print(f"Generated episode {episode} of {num_episodes}. "
                  f"Return: {episode_return}")
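
The create_video variants in examples #11-#13 render py_environment while stepping tf_environment, which works when the TF environment wraps that same Python environment. A hedged setup sketch; the environment and the random policy are assumptions for illustration:

# Sketch only: TFPyEnvironment(py_env) wraps py_env, so tf_env.step() also
# advances py_env and therefore its render() output.
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy

py_env = suite_gym.load('CartPole-v1')
tf_env = tf_py_environment.TFPyEnvironment(py_env)
policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                         tf_env.action_spec())
create_video(py_env, tf_env, policy, num_episodes=2, video_filename='eval.mp4')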
Example #14
File: wrappers.py Project: wau/agents
    def __init__(self, env: py_environment.PyEnvironment,
                 num_extra_actions: int):
        """Initializes an instance of `ExtraDisabledActionsWrapper`.

    Args:
      env: The environment to wrap.
      num_extra_actions: The number of extra actions to add.
    """
        super(ExtraDisabledActionsWrapper, self).__init__(env)
        orig_action_spec = env.action_spec()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=orig_action_spec.shape,
            dtype=orig_action_spec.dtype,
            minimum=orig_action_spec.minimum,
            maximum=orig_action_spec.maximum + num_extra_actions)
        mask_spec = array_spec.ArraySpec(
            shape=[self._action_spec.maximum - self._action_spec.minimum + 1],
            dtype=np.int64)
        self._masked_observation_spec = (env.observation_spec(), mask_spec)
        self._constant_mask = np.array(
            [[1] * (orig_action_spec.maximum - orig_action_spec.minimum + 1) +
             [0] * num_extra_actions] * self.batch_size)
Example #15
def create_obs_stacker(environment: py_environment.PyEnvironment,
                       history_size: int = 3):
    """Creates an observation stacker.

    Args:
      environment: Gathering object.
      history_size: int, number of steps to stack.

    Returns:
      An observation stacker object.
    """

    return ObservationStacker(history_size, environment.single_obs_shape())
Example #16
def validate_py_environment(environment: py_environment.PyEnvironment,
                            episodes: int = 5):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, time_step_spec):
            raise ValueError(
                'Given `time_step`: %r does not match expected `time_step_spec`: %r'
                % (time_step, time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        if time_step.is_last():
            episode_count += 1
            time_step = environment.reset()
Example #17
File: wrappers.py Project: wau/agents
    def __init__(self,
                 env: py_environment.PyEnvironment,
                 observations_allowlist: Optional[Sequence[Text]] = None):
        """Initializes a wrapper to flatten environment observations.

    Args:
      env: A `py_environment.PyEnvironment` environment to wrap.
      observations_allowlist: A list of observation keys that want to be
        observed from the environment.  All other observations returned are
        filtered out.  If not provided, all observations will be kept.
        Additionally, if this is provided, the environment is expected to return
        a dictionary of observations.

    Raises:
      ValueError: If the current environment does not return a dictionary of
        observations and observations_allowlist is provided.
      ValueError: If the observation_allowlist keys are not found in the
        environment.
    """
        super(FlattenObservationsWrapper, self).__init__(env)

        # If observations allowlist is provided:
        #  Check that the environment returns a dictionary of observations.
        #  Check that the set of allowed keys is a found in the environment keys.
        if observations_allowlist is not None:
            if not isinstance(env.observation_spec(), dict):
                raise ValueError(
                    'If you provide an observations allowlist, the current environment '
                    'must return a dictionary of observations! The returned observation'
                    ' spec is type %s.' % (type(env.observation_spec())))

            # Check that observation allowlist keys are valid observation keys.
            if not (set(observations_allowlist).issubset(
                    env.observation_spec().keys())):
                raise ValueError(
                    'The observation allowlist contains keys not found in the '
                    'environment! Unknown keys: %s' % list(
                        set(observations_allowlist).difference(
                            env.observation_spec().keys())))

        # Check that all observations have the same dtype. This dtype will be used
        # to create the flattened ArraySpec.
        env_dtypes = list(
            set([obs.dtype for obs in env.observation_spec().values()]))
        if len(env_dtypes) != 1:
            raise ValueError(
                'The observation spec must all have the same dtypes! '
                'Currently found dtypes: %s' % (env_dtypes))
        inferred_spec_dtype = env_dtypes[0]

        self._observation_spec_dtype = inferred_spec_dtype
        self._observations_allowlist = observations_allowlist
        # Update the observation spec in the environment.
        observations_spec = env.observation_spec()
        if self._observations_allowlist is not None:
            observations_spec = self._filter_observations(observations_spec)

        # Compute the observation length after flattening the observation items and
        # nested structure. Observation specs are not batched.
        observation_total_len = sum(
            int(np.prod(observation.shape))
            for observation in self._flatten_nested_observations(
                observations_spec, is_batched=False))

        # Update the observation spec as an array of one-dimension.
        self._flattened_observation_spec = array_spec.ArraySpec(
            shape=(observation_total_len, ),
            dtype=self._observation_spec_dtype,
            name='packed_observations')
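
Finally, a usage sketch for the flattening wrapper, assuming a hypothetical environment whose observation_spec() is a dict of same-dtype ArraySpecs; the class name and observation keys below are invented for illustration:

# Sketch only: MyDictObsEnv is a hypothetical PyEnvironment whose observations
# are {'position': float32 (3,), 'velocity': float32 (3,)}.
env = MyDictObsEnv()
flat_env = FlattenObservationsWrapper(env, observations_allowlist=['position'])
print(flat_env.observation_spec())   # expected: a single (3,) float32 ArraySpec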