Пример #1
0
    def test_simple(self):
        """Checks utils.batch_apply on a pair of rank-3 integer tensors.

        The applied function reduces the last axis of each input (sum for the
        first, max for the second); the expected results keep the two leading
        dimensions of the inputs.
        """
        def reduce_last_axis(first, second):
            summed = tf.reduce_sum(first, axis=-1)
            maxed = tf.reduce_max(second, axis=-1)
            return summed, maxed

        # Both inputs have shape [2, 2, 2].
        sum_input = tf.constant([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
        max_input = tf.constant([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])

        a_sum, b_max = utils.batch_apply(reduce_last_axis,
                                         (sum_input, max_input))

        expected_sum = tf.constant([[1, 5], [9, 13]])
        expected_max = tf.constant([[9, 11], [13, 15]])
        self.assertAllEqual(expected_sum, a_sum)
        self.assertAllEqual(expected_max, b_max)
Пример #2
0
    def _unroll(self, prev_actions, env_outputs, core_state):
        """Runs torso, recurrent core and head over a sequence of transitions.

        The core is stepped once per slice along the leading axis of the torso
        outputs (presumably the time axis -- confirm against callers). Before
        each step, entries whose `done` flag is set have their core state
        replaced by the core's initial state.

        Args:
          prev_actions: Previous actions; its second dimension is used as the
            batch size for the initial core state.
          env_outputs: 5-tuple whose second element is the `done` tensor.
          core_state: Recurrent state structure at the start of the sequence.

        Returns:
          A pair (head outputs for the stacked core outputs, final core state).
        """
        _, done, _, _, _ = env_outputs

        torso_outputs = utils.batch_apply(self._torso,
                                          (prev_actions, env_outputs))

        fresh_core_state = self._core.get_initial_state(
            batch_size=tf.shape(prev_actions)[1], dtype=tf.float32)

        def reset_where_done(step_done, current_state):
            """Swaps in the initial state wherever `step_done` is set."""
            def select(fresh, current):
                # Pad `step_done` with trailing singleton dims up to the rank
                # of the state tensor so that tf.where can broadcast it.
                mask_shape = [step_done.shape[0]] + [1] * (fresh.shape.rank - 1)
                return tf.where(tf.reshape(step_done, mask_shape),
                                fresh, current)

            return tf.nest.map_structure(select, fresh_core_state,
                                         current_state)

        step_outputs = []
        steps = zip(tf.unstack(torso_outputs), tf.unstack(done))
        for step_input, step_done in steps:
            # The reset happens before the core consumes this step's input.
            core_state = reset_where_done(step_done, core_state)
            step_output, core_state = self._core(step_input, core_state)
            step_outputs.append(step_output)

        stacked_core_outputs = tf.stack(step_outputs)
        head_outputs = utils.batch_apply(self._head, (stacked_core_outputs, ))
        return head_outputs, core_state
Пример #3
0
  def _unroll(self, prev_actions, env_outputs, agent_state):
    """Stacks frames, then runs torso, recurrent core and head on an unroll.

    Args:
      prev_actions: Previous actions, forwarded to the torso.
      env_outputs: (reward, done, observation) named tuple; tensors carry
        [time, batch_size] front dimensions.
      agent_state: State holding `frame_stacking_state` and `core_state`.

    Returns:
      A pair (head outputs, AgentState(core state, frame stacking state)).
    """
    # Each field is laid out as [time, batch_size, <field shape>].
    _, done, raw_observation = env_outputs
    float_observation = tf.cast(raw_observation, tf.float32)

    fresh_agent_state = self.initial_state(batch_size=tf.shape(done)[1])

    frames, frame_state = stack_frames(
        float_observation,
        agent_state.frame_stacking_state,
        done,
        self._stack_size)

    # The torso sees the stacked frames divided by 255 instead of the raw
    # observation; the other fields are passed through unchanged.
    scaled_env_outputs = env_outputs._replace(observation=frames / 255)

    # Torso outputs: [time, batch_size, torso_output_size].
    torso_outputs = utils.batch_apply(
        self._torso, (prev_actions, scaled_env_outputs))

    core_outputs, core_state = _unroll_cell(
        torso_outputs,
        done,
        agent_state.core_state,
        fresh_agent_state.core_state,
        self._core)

    head_outputs = utils.batch_apply(self._head, (core_outputs,))
    return head_outputs, AgentState(core_state, frame_state)
Пример #4
0
  def __call__(self, input_, core_state, unroll=False, is_training=False):
    """Runs the network on one batched transition or on an unroll.

    Args:
      input_: A pair (prev_actions, env_outputs). Only env_outputs is used
        here. Its tensors have a [batch_size] front dimension, or
        [time, batch_size] front dimensions when `unroll` is True.
      core_state: Opaque (batched) recurrent state structure for the start of
        the input sequence.
      unroll: True when the input is a sequence of transitions, False when it
        is a single (batched) transition.
      is_training: Together with `unroll`, marks a gradient-update call. In
        this method the flag is only used to reject initializing the
        normalization statistics during such a call.

    Returns:
      A pair:
        - outputs: Network outputs with front dimensions [batch_size] or
          [time, batch_size], depending on `unroll`.
        - core_state: Opaque (batched) recurrent state structure.

    Raises:
      ValueError: If the observation normalizer is still uninitialized when
        both `is_training` and `unroll` are set.
    """
    _, env_outputs = input_

    # `is_training` is also True for inference steps taken during the training
    # phase, so only `is_training` combined with `unroll` identifies a
    # gradient-update call.
    is_gradient_update = is_training and unroll
    observation = env_outputs[2]
    if not self.observation_normalizer.initialized:
      if is_gradient_update:
        raise ValueError('It seems unlikely that stats should be updated in the'
                         ' same call where the stats are initialized.')
      # The statistics are sized from the last observation dimension.
      self.observation_normalizer.init_normalization_stats(
          observation.shape[-1])

    uses_rnn = self._rnn is not None
    if uses_rnn and unroll:
      # Recurrent network over a whole unroll.
      features = utils.batch_apply(self._flat_apply_pre_lstm, (env_outputs,))
      features, core_state = self._apply_rnn(
          features, core_state, env_outputs.done)
      outputs = utils.batch_apply(self._flat_apply_post_lstm, (features,))
    elif uses_rnn:
      # Recurrent network on a single transition: add a leading axis of size 1
      # around the RNN call and strip it again afterwards.
      def add_leading_axis(tensor):
        return tf.expand_dims(tensor, 0)

      def drop_leading_axis(tensor):
        return tf.squeeze(tensor, 0)

      features = self._flat_apply_pre_lstm(env_outputs)
      features, done = tf.nest.map_structure(
          add_leading_axis, (features, env_outputs.done))
      features, core_state = self._apply_rnn(features, core_state, done)
      features = tf.nest.map_structure(drop_leading_axis, features)
      outputs = self._flat_apply_post_lstm(features)
    elif unroll:
      # No recurrent core: core_state is returned untouched.
      outputs = utils.batch_apply(self._flat_apply_no_lstm, (env_outputs,))
    else:
      outputs = self._flat_apply_no_lstm(env_outputs)

    return outputs, core_state
Пример #5
0
 def _unroll(self, prev_actions, env_outputs, core_state):
     """Applies the torso and then the head via utils.batch_apply.

     There is no recurrent step here: `core_state` is returned exactly as it
     was passed in.

     Returns:
       A pair (head outputs, core_state).
     """
     torso_inputs = (prev_actions, env_outputs)
     head_inputs = (utils.batch_apply(self._torso, torso_inputs), )
     head_outputs = utils.batch_apply(self._head, head_inputs)
     return head_outputs, core_state