def _testDynamicUnrollResetsStateOnReset(self, cell_type):
  cell = cell_type()
  batch_size = 4
  max_time = 7
  inputs = tf.random_uniform((batch_size, max_time, 1))
  reset_mask = (tf.random_normal((batch_size, max_time)) > 0)
  outputs, final_state, _ = rnn_utils.dynamic_unroll(
      cell, inputs, reset_mask, dtype=tf.float32)

  nest.assert_same_structure(outputs, cell.output_size)
  nest.assert_same_structure(final_state, cell.state_size)

  reset_mask, inputs, outputs, final_state = self.evaluate(
      (reset_mask, inputs, outputs, final_state))

  self.assertAllClose(outputs[:, -1, :], final_state)

  # outputs will contain cumulative sums up until a reset.
  expected_outputs = []
  state = np.zeros_like(final_state)
  for i, frame in enumerate(np.transpose(inputs, [1, 0, 2])):
    state = state * np.reshape(~reset_mask[:, i], state.shape) + frame
    expected_outputs.append(np.array(state))
  expected_outputs = np.transpose(expected_outputs, [1, 0, 2])
  self.assertAllClose(outputs, expected_outputs)
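# The numpy reference loop above only matches a cell whose output equals a
# running sum of its inputs, zeroed wherever `reset_mask` is True. Below is a
# minimal sketch of such a cell, intended only to illustrate what `cell_type`
# is assumed to provide; it is not the test's actual parameterization.
class _AccumulatorCell(tf.nn.rnn_cell.RNNCell):
  """RNN cell whose state and output are the cumulative sum of its inputs."""

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def call(self, inputs, state):
    new_state = state + inputs  # output == state == running sum of inputs
    return new_state, new_state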
def call(self, observation, action, step_type, network_state=None):
  num_outer_dims = nest_utils.get_outer_rank(observation,
                                             self.observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                     observation)
    action = nest.map_structure(lambda t: tf.expand_dims(t, 1), action)
    step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

  observation = tf.to_float(nest.flatten(observation)[0])
  action = tf.to_float(nest.flatten(action)[0])

  batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
  observation = batch_squash.flatten(observation)  # [B, T, ...] -> [B x T, ...]
  action = batch_squash.flatten(action)

  for layer in self._observation_layers:
    observation = layer(observation)

  for layer in self._action_layers:
    action = layer(action)

  joint = tf.concat([observation, action], -1)
  for layer in self._joint_layers:
    joint = layer(joint)

  joint = batch_squash.unflatten(joint)  # [B x T, ...] -> [B, T, ...]

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  joint, network_state, _ = rnn_utils.dynamic_unroll(
      self._cell,
      joint,
      reset_mask,
      initial_state=network_state,
      dtype=tf.float32)

  output = batch_squash.flatten(joint)  # [B, T, ...] -> [B x T, ...]
  for layer in self._output_layers:
    output = layer(output)

  q_value = tf.reshape(output, [-1])
  q_value = batch_squash.unflatten(q_value)  # [B x T, ...] -> [B, T, ...]
  if not has_time_dim:
    q_value = tf.squeeze(q_value, axis=1)

  return q_value, network_state
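# Hedged sketch of what the BatchSquash(2) flatten/unflatten round trip above
# is relied on to do: fold the leading batch and time dimensions into one so
# plain feed-forward layers see a [B x T, ...] batch, then restore [B, T, ...]
# afterwards. The helpers below assume fully static shapes; the real utility
# also copes with partially unknown dimensions.
def _flatten_batch_and_time(x):
  """[B, T, ...] -> [B * T, ...] for a tensor with a known static shape."""
  shape = x.shape.as_list()
  return tf.reshape(x, [shape[0] * shape[1]] + shape[2:])


def _unflatten_batch_and_time(x, batch_size, time_steps):
  """[B * T, ...] -> [B, T, ...]."""
  shape = x.shape.as_list()
  return tf.reshape(x, [batch_size, time_steps] + shape[1:])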
def call(self, observation, step_type, network_state=None):
  num_outer_dims = nest_utils.get_outer_rank(observation,
                                             self.observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                     observation)
    step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

  states = tf.cast(nest.flatten(observation)[0], tf.float32)

  batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
  states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

  for layer in self._input_layers:
    states = layer(states)

  states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  states, network_state, _ = rnn_utils.dynamic_unroll(
      self._cell,
      states,
      reset_mask,
      initial_state=network_state,
      dtype=tf.float32)

  states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

  for layer in self._output_layers:
    states = layer(states)

  actions = []
  for layer, spec in zip(self._action_layers, self._flat_action_spec):
    action = layer(states)
    action = common_utils.scale_to_spec(action, spec)
    action = batch_squash.unflatten(action)  # [B x T, ...] -> [B, T, ...]
    if not has_time_dim:
      action = tf.squeeze(action, axis=1)
    actions.append(action)

  return nest.pack_sequence_as(self._action_spec, actions), network_state
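# Hedged sketch of the spec scaling applied to each action head above,
# assuming the action layer's output is already squashed to [-1, 1] (e.g. by
# a tanh activation) and that `spec` exposes numeric `minimum`/`maximum`
# bounds. The real `common_utils.scale_to_spec` may differ in details such as
# reshaping to the spec's shape.
def _scale_to_spec_sketch(squashed, spec):
  means = (spec.maximum + spec.minimum) / 2.0
  magnitudes = (spec.maximum - spec.minimum) / 2.0
  # Affine map of [-1, 1] onto [spec.minimum, spec.maximum].
  return means + magnitudes * squashed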
def call(self, observation, step_type, network_state=None):
  num_outer_dims = nest_utils.get_outer_rank(observation,
                                             self._observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                     observation)
    step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

  state = tf.to_float(nest.flatten(observation)[0])

  num_feature_dims = 3 if self._conv_layer_params else 1
  state.shape.with_rank_at_least(num_feature_dims)
  batch_squash = utils.BatchSquash(state.shape.ndims - num_feature_dims)

  state = batch_squash.flatten(state)
  state, network_state = self._input_encoder(state, step_type, network_state)
  state = batch_squash.unflatten(state)

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  state, network_state, _ = rnn_utils.dynamic_unroll(
      self._cell,
      state,
      reset_mask,
      initial_state=network_state,
      dtype=tf.float32)

  state = batch_squash.flatten(state)
  for layer in self._output_encoder:
    state = layer(state)
  state = batch_squash.unflatten(state)

  if not has_time_dim:
    # Remove the time dimension from the state.
    state = tf.squeeze(state, [1])

  return state, network_state
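# Hedged illustration of the squash arithmetic above, using made-up shapes:
# with conv layers a single frame is image-like [H, W, C] (3 feature dims),
# otherwise it is a flat vector (1 feature dim), so BatchSquash folds every
# remaining outer (batch/time) dimension into one.
def _example_num_dims_to_squash(observation_ndims, has_conv_layers):
  num_feature_dims = 3 if has_conv_layers else 1
  return observation_ndims - num_feature_dims

assert _example_num_dims_to_squash(5, True) == 2   # [B, T, H, W, C]
assert _example_num_dims_to_squash(3, False) == 2  # [B, T, F]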
def call(self, observation, step_type, network_state=None):
  num_outer_dims = nest_utils.get_outer_rank(observation,
                                             self.observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                     observation)
    step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

  states = tf.to_float(nest.flatten(observation)[0])

  batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
  states = batch_squash.flatten(states)

  for layer in self._input_layers:
    states = layer(states)

  states = batch_squash.unflatten(states)

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  states, network_state, _ = rnn_utils.dynamic_unroll(
      self._cell,
      states,
      reset_mask,
      initial_state=network_state,
      dtype=tf.float32)

  states = batch_squash.flatten(states)
  for layer in self._output_layers:
    states = layer(states)
  states = batch_squash.unflatten(states)

  outputs = [
      projection(states, num_outer_dims)
      for projection in self._projection_networks
  ]
  return nest.pack_sequence_as(self._action_spec, outputs), network_state
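# Hedged sketch of the final packing step, with made-up placeholder values:
# the projection networks return a flat list (one output per entry of the
# flattened action spec), and `nest.pack_sequence_as` rebuilds the caller's
# nested action structure from that list.
_example_action_spec = {'steer': None, 'throttle': None}
_example_flat_outputs = ['distribution_for_steer', 'distribution_for_throttle']
# Flat order follows the sorted dict keys, so this yields
# {'steer': 'distribution_for_steer', 'throttle': 'distribution_for_throttle'}.
_example_packed = nest.pack_sequence_as(_example_action_spec,
                                        _example_flat_outputs)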
def testDynamicUnrollMatchesDynamicRNNWhenNoResetSingleTimeStep(self):
  cell = tf.nn.rnn_cell.LSTMCell(3)
  batch_size = 4
  max_time = 1
  inputs = tf.random_uniform((batch_size, max_time, 2), dtype=tf.float32)
  reset_mask = tf.zeros((batch_size, max_time), dtype=tf.bool)

  outputs_dun, final_state_dun, _ = rnn_utils.dynamic_unroll(
      cell, inputs, reset_mask, dtype=tf.float32)
  outputs_drnn, final_state_drnn = tf.nn.dynamic_rnn(
      cell, inputs, dtype=tf.float32)

  self.evaluate(tf.global_variables_initializer())
  outputs_dun, final_state_dun, outputs_drnn, final_state_drnn = self.evaluate(
      (outputs_dun, final_state_dun, outputs_drnn, final_state_drnn))

  self.assertAllClose(outputs_dun, outputs_drnn)
  self.assertAllClose(final_state_dun, final_state_drnn)