def test_simple(self):
  """batch_apply maps a two-input function over merged leading dimensions."""

  def sum_and_max(x, y):
    # Reduce over the last axis only; the two leading axes stay intact.
    return tf.reduce_sum(x, axis=-1), tf.reduce_max(y, axis=-1)

  lhs = tf.constant([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
  rhs = tf.constant([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])
  summed, maxed = utils.batch_apply(sum_and_max, (lhs, rhs))

  self.assertAllEqual(tf.constant([[1, 5], [9, 13]]), summed)
  self.assertAllEqual(tf.constant([[9, 11], [13, 15]]), maxed)
def _unroll(self, prev_actions, env_outputs, core_state):
  """Unrolls torso, recurrent core and head over a sequence of steps.

  The core is stepped once per element along axis 0 of the torso outputs.
  Before each step, every example whose `done` flag is set has its core state
  replaced by the initial core state.

  Args:
    prev_actions: Previous actions. Axis 1 is used as the batch size of the
      initial core state (presumably [time, batch_size] - TODO confirm).
    env_outputs: Five-field environment-output structure whose second field
      is the per-step `done` flag, unstacked along axis 0.
    core_state: Recurrent core state at the start of the unroll.

  Returns:
    A pair (head outputs for all steps, core state after the last step).
  """
  unused_reward, done, unused_observation, _, _ = env_outputs

  torso_outputs = utils.batch_apply(self._torso, (prev_actions, env_outputs))

  initial_core_state = self._core.get_initial_state(
      batch_size=tf.shape(prev_actions)[1], dtype=tf.float32)

  core_output_list = []
  for input_, d in zip(tf.unstack(torso_outputs), tf.unstack(done)):
    # If the episode ended, the core state should be reset before the next.
    # The leading dimension is -1 rather than the static `d.shape[0]`. The
    # static value is None when the batch size is only known at runtime,
    # which would break the reshape. The initial state above is already
    # sized dynamically.
    core_state = tf.nest.map_structure(
        lambda initial, current, d=d: tf.where(
            tf.reshape(d, [-1] + [1] * (initial.shape.rank - 1)),
            initial, current),
        initial_core_state,
        core_state)
    core_output, core_state = self._core(input_, core_state)
    core_output_list.append(core_output)
  core_outputs = tf.stack(core_output_list)

  return utils.batch_apply(self._head, (core_outputs,)), core_state
def _unroll(self, prev_actions, env_outputs, agent_state):
  """Applies frame stacking, torso, recurrent core and head over an unroll.

  Args:
    prev_actions: Previous actions, passed to the torso with the env outputs.
    env_outputs: Three-field structure (reward, done, observation); tensors
      have [time, batch_size, <field shape>] leading dimensions.
    agent_state: AgentState holding `core_state` and `frame_stacking_state`.

  Returns:
    A pair (head output, AgentState(new core state, new frame state)).
  """
  _, done_flags, raw_observation = env_outputs
  float_observation = tf.cast(raw_observation, tf.float32)
  # Axis 1 of `done` is taken as the batch size for the fresh initial state.
  fresh_state = self.initial_state(batch_size=tf.shape(done_flags)[1])

  stacked_frames, frame_state = stack_frames(
      float_observation, agent_state.frame_stacking_state, done_flags,
      self._stack_size)
  # The torso sees the stacked frames divided by 255 in place of the raw
  # observation.
  scaled_env_outputs = env_outputs._replace(observation=stacked_frames / 255)

  # [time, batch_size, torso_output_size]
  torso_outputs = utils.batch_apply(
      self._torso, (prev_actions, scaled_env_outputs))
  core_outputs, core_state = _unroll_cell(
      torso_outputs, done_flags, agent_state.core_state,
      fresh_state.core_state, self._core)

  return (utils.batch_apply(self._head, (core_outputs,)),
          AgentState(core_state, frame_state))
def __call__(self, input_, core_state, unroll=False, is_training=False):
  """Runs the network on one batched transition or on an unroll.

  Args:
    input_: Pair (prev_actions, env_outputs). prev_actions is
      <int32>[batch_size], and every tensor of the EnvOutput structure
      env_outputs has a leading [batch_size] dimension. With unroll=True
      (a sequence of transitions), both carry [time, batch_size] leading
      dimensions instead.
    core_state: Opaque (batched) recurrent state for the beginning of the
      input sequence of transitions.
    unroll: True when input_ is a sequence of transitions, False for a
      single (batched) transition.
    is_training: Described upstream as enabling normalization statistics
      updates (when unroll is True).
      NOTE(review): in this method it is only used to reject such an update
      while the normalizer is uninitialized. No statistics update is visible
      here, so confirm where it actually happens.

  Returns:
    Pair (agent_output, core_state):
      - agent_output: AgentOutput structure whose tensors have [batch_size]
        or [time, batch_size] front dimensions, depending on `unroll`.
      - core_state: Opaque (batched) recurrent state after the input.

  Raises:
    ValueError: If the observation normalizer is not yet initialized while
      both `is_training` and `unroll` are set.
  """
  unused_prev_actions, env_outputs = input_

  # Only gradient-update steps count as model updates. `is_training` is also
  # True for inference steps during the training phase, so `unroll` is
  # required as well.
  is_model_update = is_training and unroll
  observations = env_outputs[2]
  if not self.observation_normalizer.initialized:
    if is_model_update:
      raise ValueError('It seems unlikely that stats should be updated in the'
                       ' same call where the stats are initialized.')
    self.observation_normalizer.init_normalization_stats(
        observations.shape[-1])

  if self._rnn is None:
    # Feed-forward path: no recurrent core, core_state is returned unchanged.
    if unroll:
      return (utils.batch_apply(self._flat_apply_no_lstm, (env_outputs,)),
              core_state)
    return self._flat_apply_no_lstm(env_outputs), core_state

  if unroll:
    pre_rnn = utils.batch_apply(self._flat_apply_pre_lstm, (env_outputs,))
    post_rnn, core_state = self._apply_rnn(
        pre_rnn, core_state, env_outputs.done)
    return (utils.batch_apply(self._flat_apply_post_lstm, (post_rnn,)),
            core_state)

  # Single transition: add a length-1 leading axis so _apply_rnn receives the
  # same layout as in the unroll case, then remove it from the result.
  def _with_time_axis(t):
    return tf.expand_dims(t, 0)

  def _without_time_axis(t):
    return tf.squeeze(t, 0)

  pre_rnn = self._flat_apply_pre_lstm(env_outputs)
  pre_rnn, done = tf.nest.map_structure(
      _with_time_axis, (pre_rnn, env_outputs.done))
  post_rnn, core_state = self._apply_rnn(pre_rnn, core_state, done)
  post_rnn = tf.nest.map_structure(_without_time_axis, post_rnn)
  return self._flat_apply_post_lstm(post_rnn), core_state
def _unroll(self, prev_actions, env_outputs, core_state):
  """Applies torso then head to every step of an unroll.

  There is no recurrent core here: `core_state` is returned unchanged.
  """
  torso_out = utils.batch_apply(self._torso, (prev_actions, env_outputs))
  head_out = utils.batch_apply(self._head, (torso_out,))
  return head_out, core_state