Example #1
0
    def _testDynamicUnrollResetsStateOnReset(self, cell_type):
        """Checks that dynamic_unroll clears the cell state wherever reset_mask is set.

        Args:
          cell_type: Zero-argument callable that builds the RNN cell under test.
            The numpy reference below models the cell as a running sum of its
            inputs whose output equals its state, so `cell_type` is presumably
            such a summing cell -- confirm at the call sites.
        """
        rnn_cell = cell_type()
        num_batch, num_steps = 4, 7
        inputs = tf.random_uniform((num_batch, num_steps, 1))
        reset_mask = tf.random_normal((num_batch, num_steps)) > 0

        outputs, final_state, _ = rnn_utils.dynamic_unroll(
            rnn_cell, inputs, reset_mask, dtype=tf.float32)

        # The unrolled tensors must mirror the structures the cell declares.
        nest.assert_same_structure(outputs, rnn_cell.output_size)
        nest.assert_same_structure(final_state, rnn_cell.state_size)

        reset_mask, inputs, outputs, final_state = self.evaluate(
            (reset_mask, inputs, outputs, final_state))

        # The output emitted at the last step equals the final state.
        self.assertAllClose(outputs[:, -1, :], final_state)

        # Numpy reference: a cumulative sum that restarts from zero at every
        # step flagged in reset_mask.
        running = np.zeros_like(final_state)
        per_step = []
        for step in range(inputs.shape[1]):
            keep = np.reshape(~reset_mask[:, step], running.shape)
            running = running * keep + inputs[:, step, :]
            per_step.append(np.array(running))
        self.assertAllClose(outputs, np.stack(per_step, axis=1))
Example #2
0
    def call(self, observation, action, step_type, network_state=None):
        """Computes Q-values for (observation, action) pairs with a recurrent critic.

        The observation and action are encoded by separate layer stacks,
        concatenated, passed through the joint layers, unrolled through
        `self._cell`, and finally projected by the output layers.

        Args:
          observation: Nest matching `self.observation_spec` with a [B] or
            [B, T] outer shape. Only its first flattened tensor is used.
          action: Nest with the same outer shape as `observation`; only its
            first flattened tensor is used.
          step_type: Step types with the same outer shape; entries equal to
            `time_step.StepType.FIRST` reset the RNN state.
          network_state: Optional initial state for `self._cell`.

        Returns:
          A `(q_value, network_state)` tuple. `q_value` is reshaped to the
          [B, T] outer shape (assuming the last output layer has a single
          unit -- confirm), with the time axis removed when the input had none.

        Raises:
          ValueError: If `observation` has neither one nor two outer dims.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs so one code path handles both.
            observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                             observation)
            action = nest.map_structure(lambda t: tf.expand_dims(t, 1), action)
            step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           step_type)

        # tf.cast(x, tf.float32) is the supported replacement for the
        # deprecated tf.to_float.
        observation = tf.cast(nest.flatten(observation)[0], tf.float32)
        action = tf.cast(nest.flatten(action)[0], tf.float32)

        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        observation = batch_squash.flatten(
            observation)  # [B, T, ...] -> [BxT, ...]
        action = batch_squash.flatten(action)

        for layer in self._observation_layers:
            observation = layer(observation)

        for layer in self._action_layers:
            action = layer(action)

        joint = tf.concat([observation, action], -1)
        for layer in self._joint_layers:
            joint = layer(joint)

        joint = batch_squash.unflatten(joint)  # [B x T, ...] -> [B, T, ...]

        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence; state is reset where reset_mask is True.
        joint, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            joint,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        output = batch_squash.flatten(joint)  # [B, T, ...] -> [B x T, ...]

        for layer in self._output_layers:
            output = layer(output)

        q_value = tf.reshape(output, [-1])
        q_value = batch_squash.unflatten(
            q_value)  # [B x T, ...] -> [B, T, ...]
        if not has_time_dim:
            # Drop the time axis added above.
            q_value = tf.squeeze(q_value, axis=1)

        return q_value, network_state
Example #3
0
    def call(self, observation, step_type, network_state=None):
        """Runs the recurrent actor over a batch (or batch x time) of observations.

        Args:
          observation: Nest matching `self.observation_spec` with a [B] or
            [B, T] outer shape. Only its first flattened tensor is used.
          step_type: Step types with the same outer shape as `observation`;
            entries equal to `time_step.StepType.FIRST` reset the RNN state.
          network_state: Optional initial state for `self._cell`.

        Returns:
          A `(actions, network_state)` tuple. `actions` is packed like
          `self._action_spec`; each entry is scaled to its spec and keeps the
          time axis only when the input had one.

        Raises:
          ValueError: If `observation` has neither one nor two outer dims.
        """
        outer_rank = nest_utils.get_outer_rank(observation,
                                               self.observation_spec)
        if outer_rank != 1 and outer_rank != 2:
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )
        was_sequence = outer_rank == 2

        if not was_sequence:
            # Give single-step inputs a length-1 time axis so that a single
            # code path handles both cases.
            def _add_time_axis(t):
                return tf.expand_dims(t, 1)

            observation = nest.map_structure(_add_time_axis, observation)
            step_type = nest.map_structure(_add_time_axis, step_type)

        first_obs = nest.flatten(observation)[0]
        hidden = tf.cast(first_obs, tf.float32)
        squash = utils.BatchSquash(2)  # Merges the B and T dims.

        # Feedforward encoding applied per step on the merged [B x T, ...] tensor.
        hidden = squash.flatten(hidden)
        for input_layer in self._input_layers:
            hidden = input_layer(hidden)
        hidden = squash.unflatten(hidden)  # back to [B, T, ...]

        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Recur over time; the cell state is reset wherever reset_mask is True.
        hidden, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            hidden,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        hidden = squash.flatten(hidden)  # [B, T, ...] -> [B x T, ...]
        for output_layer in self._output_layers:
            hidden = output_layer(hidden)

        def _action_head(head_layer, head_spec):
            # Project through one head, scale to its spec, restore outer dims.
            scaled = common_utils.scale_to_spec(head_layer(hidden), head_spec)
            scaled = squash.unflatten(scaled)  # [B x T, ...] -> [B, T, ...]
            if was_sequence:
                return scaled
            return tf.squeeze(scaled, axis=1)

        actions = [
            _action_head(head_layer, head_spec)
            for head_layer, head_spec in zip(self._action_layers,
                                             self._flat_action_spec)
        ]
        return nest.pack_sequence_as(self._action_spec, actions), network_state
Example #4
0
    def call(self, observation, step_type, network_state=None):
        """Encodes observations, unrolls them through the LSTM cell, then re-encodes.

        Args:
          observation: Nest matching `self._observation_spec` with a [B] or
            [B, T] outer shape. Only its first flattened tensor is used.
          step_type: Step types with the same outer shape as `observation`;
            entries equal to `time_step.StepType.FIRST` reset the RNN state.
          network_state: Optional state. It is passed to `self._input_encoder`,
            and the state that encoder returns is used as the initial state of
            `self._cell`.

        Returns:
          A `(state, network_state)` tuple. `state` keeps the time axis only
          when the input had one.

        Raises:
          ValueError: If `observation` has neither one nor two outer dims.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self._observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs so one code path handles both.
            observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                             observation)
            step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           step_type)

        # tf.cast(x, tf.float32) is the supported replacement for the
        # deprecated tf.to_float.
        state = tf.cast(nest.flatten(observation)[0], tf.float32)

        # 3 trailing feature dims when conv layers are configured (presumably
        # H, W, C), otherwise 1; every leading dim is squashed into the batch.
        num_feature_dims = 3 if self._conv_layer_params else 1
        # Validation only: raises if the static rank is below num_feature_dims.
        state.shape.with_rank_at_least(num_feature_dims)
        batch_squash = utils.BatchSquash(state.shape.ndims - num_feature_dims)

        state = batch_squash.flatten(state)
        # NOTE(review): network_state is routed through the input encoder and
        # whatever it returns seeds the RNN below; presumably the encoder passes
        # it through unchanged -- confirm.
        state, network_state = self._input_encoder(state, step_type,
                                                   network_state)
        state = batch_squash.unflatten(state)

        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence; state is reset where reset_mask is True.
        state, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            state,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        state = batch_squash.flatten(state)
        for layer in self._output_encoder:
            state = layer(state)
        state = batch_squash.unflatten(state)

        if not has_time_dim:
            # Remove time dimension from the state.
            state = tf.squeeze(state, [1])

        return state, network_state
    def call(self, observation, step_type, network_state=None):
        """Runs the recurrent actor body and applies each projection network.

        Args:
          observation: Nest matching `self.observation_spec` with a [B] or
            [B, T] outer shape. Only its first flattened tensor is used.
          step_type: Step types with the same outer shape as `observation`;
            entries equal to `time_step.StepType.FIRST` reset the RNN state.
          network_state: Optional initial state for `self._cell`.

        Returns:
          A `(outputs, network_state)` tuple. `outputs` is packed like
          `self._action_spec`, with one entry per projection network. Each
          projection receives states whose outer rank equals that of
          `observation`.

        Raises:
          ValueError: If `observation` has neither one nor two outer dims.
        """
        num_outer_dims = nest_utils.get_outer_rank(observation,
                                                   self.observation_spec)
        if num_outer_dims not in (1, 2):
            raise ValueError(
                'Input observation must have a batch or batch x time outer shape.'
            )

        has_time_dim = num_outer_dims == 2
        if not has_time_dim:
            # Add a time dimension to the inputs so one code path handles both.
            observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                             observation)
            step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                           step_type)

        # tf.cast(x, tf.float32) is the supported replacement for the
        # deprecated tf.to_float.
        states = tf.cast(nest.flatten(observation)[0], tf.float32)
        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

        for layer in self._input_layers:
            states = layer(states)

        states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]

        with tf.name_scope('reset_mask'):
            reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
        # Unroll over the time sequence; state is reset where reset_mask is True.
        states, network_state, _ = rnn_utils.dynamic_unroll(
            self._cell,
            states,
            reset_mask,
            initial_state=network_state,
            dtype=tf.float32)

        states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

        for layer in self._output_layers:
            states = layer(states)

        states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]
        if not has_time_dim:
            # Remove the time axis added above. Without this, `states` would be
            # [B, 1, ...] while the projections are told the outer rank is 1.
            states = tf.squeeze(states, [1])

        outputs = [
            projection(states, num_outer_dims)
            for projection in self._projection_networks
        ]

        return nest.pack_sequence_as(self._action_spec, outputs), network_state
Example #6
0
 def testDynamicUnrollMatchesDynamicRNNWhenNoResetSingleTimeStep(self):
     """dynamic_unroll and tf.nn.dynamic_rnn agree on one step with no resets.

     The same LSTM cell object drives both functions over a single time step
     with an all-False reset mask. Their outputs and final states must match.
     """
     lstm = tf.nn.rnn_cell.LSTMCell(3)
     outer_shape = (4, 1)  # (batch_size, max_time)
     inputs = tf.random_uniform(outer_shape + (2,), dtype=tf.float32)
     no_reset = tf.zeros(outer_shape, dtype=tf.bool)

     unroll_out, unroll_state, _ = rnn_utils.dynamic_unroll(
         lstm, inputs, no_reset, dtype=tf.float32)
     rnn_out, rnn_state = tf.nn.dynamic_rnn(lstm, inputs, dtype=tf.float32)

     self.evaluate(tf.global_variables_initializer())
     fetched = (unroll_out, unroll_state, rnn_out, rnn_state)
     unroll_out, unroll_state, rnn_out, rnn_state = self.evaluate(fetched)
     self.assertAllClose(unroll_out, rnn_out)
     self.assertAllClose(unroll_state, rnn_state)