Example #1
  def _decode(self, z, helper, input_shape, max_length=None):
    """Decodes the given batch of latent vectors vectors, which may be 0-length.

    Args:
      z: Batch of latent vectors, sized `[batch_size, z_size]`, where `z_size`
        may be 0 for unconditioned decoding.
      helper: A seq2seq.Helper to use.
      input_shape: The shape of each model input vector passed to the decoder.
      max_length: (Optional) The maximum number of iterations to decode.

    Returns:
      results: The LstmDecodeResults.
    """
    initial_state = lstm_utils.initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')

    decoder = lstm_utils.Seq2SeqLstmDecoder(
        self._dec_cell,
        helper,
        initial_state=initial_state,
        input_shape=input_shape,
        output_layer=self._output_layer)
    final_output, final_state, final_lengths = contrib_seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')
    results = lstm_utils.LstmDecodeResults(
        rnn_input=final_output.rnn_input[:, :, :self._output_depth],
        rnn_output=final_output.rnn_output,
        samples=final_output.sample_id,
        final_state=final_state,
        final_sequence_lengths=final_lengths)

    return results
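
For reference, here is a minimal, self-contained sketch of the same helper -> decoder -> dynamic_decode pattern outside the Magenta class. It assumes TensorFlow 1.x (where tf.contrib.seq2seq is available); the toy sizes and names are illustrative only, not from the original code:

import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_time, input_depth, cell_depth = 4, 6, 3, 8
inputs = tf.constant(
    np.random.randn(batch_size, max_time, input_depth), dtype=tf.float32)
lengths = tf.constant([6, 4, 3, 2], dtype=tf.int32)

# TrainingHelper feeds the ground-truth inputs back into the decoder at
# each step, so decoding runs out to max(lengths).
helper = seq2seq.TrainingHelper(inputs, lengths)
cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
decoder = seq2seq.BasicDecoder(
    cell=cell, helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32))

# swap_memory=True trades GPU memory for host memory on long sequences.
outputs, final_state, final_lengths = seq2seq.dynamic_decode(
    decoder, maximum_iterations=max_time, swap_memory=True, scope='decoder')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(outputs.rnn_output).shape)  # (4, 6, 8)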
Example #2
    def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(  # pylint:disable=invalid-name
            self, use_sequence_length):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)

            cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
            zero_state = cell.zero_state(dtype=tf.float32,
                                         batch_size=batch_size)
            helper = seq2seq.TrainingHelper(inputs, sequence_length)
            my_decoder = seq2seq.BasicDecoder(cell=cell,
                                              helper=helper,
                                              initial_state=zero_state)

            # Match the variable scope of dynamic_rnn below so we end up
            # using the same variables
            with tf.variable_scope("root") as scope:
                final_decoder_outputs, final_decoder_state, _ = seq2seq.dynamic_decode(
                    my_decoder,
                    # impute_finished=True ensures outputs and final state
                    # match those of dynamic_rnn called with sequence_length not None
                    impute_finished=use_sequence_length,
                    scope=scope)

            with tf.variable_scope(scope, reuse=True) as scope:
                final_rnn_outputs, final_rnn_state = tf.nn.dynamic_rnn(
                    cell,
                    inputs,
                    sequence_length=sequence_length
                    if use_sequence_length else None,
                    initial_state=zero_state,
                    scope=scope)

            sess.run(tf.global_variables_initializer())
            sess_results = sess.run({
                "final_decoder_outputs": final_decoder_outputs,
                "final_decoder_state": final_decoder_state,
                "final_rnn_outputs": final_rnn_outputs,
                "final_rnn_state": final_rnn_state
            })

            # Decoder only runs out to max_out; ensure values are identical
            # to dynamic_rnn, which also zeros out outputs and passes along state.
            self.assertAllClose(
                sess_results["final_decoder_outputs"].rnn_output,
                sess_results["final_rnn_outputs"][:, 0:max_out, :])
            if use_sequence_length:
                self.assertAllClose(sess_results["final_decoder_state"],
                                    sess_results["final_rnn_state"])
Example #3
  def _get_state(self,
                 inputs,
                 lengths=None,
                 initial_state=None):
    """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

    Args:
      inputs: A batch of sequences to compute the state from, sized
          `[batch_size, max(lengths), num_dims]` or `[batch_size, num_dims]`.
      lengths: The length of each sequence, sized `[batch_size]`.
      initial_state: An RnnNadeStateTuple, the initial state of the RNN-NADE, or
          None if the zero state should be used.

    Returns:
      final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
    """
    batch_size = int(inputs.shape[0])

    if lengths is None:
      lengths = tf.tile(tf.shape(inputs)[1:2], [batch_size])
    if initial_state is None:
      initial_rnn_state = self._get_rnn_zero_state(batch_size)
    else:
      initial_rnn_state = initial_state.rnn_state

    helper = contrib_seq2seq.TrainingHelper(
        inputs=inputs, sequence_length=lengths)

    decoder = contrib_seq2seq.BasicDecoder(
        cell=self._rnn_cell,
        helper=helper,
        initial_state=initial_rnn_state,
        output_layer=self._fc_layer)

    final_outputs, final_rnn_state = contrib_seq2seq.dynamic_decode(
        decoder)[0:2]

    # Flatten time dimension.
    final_outputs_flat = magenta.common.flatten_maybe_padded_sequences(
        final_outputs.rnn_output, lengths)

    b_enc, b_dec = tf.split(
        final_outputs_flat, [self._nade.num_hidden, self._nade.num_dims],
        axis=1)

    return RnnNadeStateTuple(b_enc, b_dec, final_rnn_state)
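
The flattening step above drops padding before the split into NADE biases. A rough equivalent using only core TF 1.x ops (a sketch, not the actual magenta.common implementation, which may differ in details):

import tensorflow as tf

def flatten_padded_sequences(sequences, lengths):
  """Concatenates the valid steps of each row into one flat batch.

  sequences: [batch_size, max_time, depth]; lengths: [batch_size].
  Returns a [sum(lengths), depth] tensor containing only real steps.
  """
  # True for positions before each row's length, False over the padding.
  mask = tf.sequence_mask(lengths, maxlen=tf.shape(sequences)[1])
  # boolean_mask with a rank-2 mask collapses batch and time into one axis.
  return tf.boolean_mask(sequences, mask)

The tf.split that follows then separates the flat outputs into the NADE encoder bias b_enc (the first num_hidden columns) and decoder bias b_dec (the remaining num_dims columns).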
Example #4
    def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):  # pylint:disable=invalid-name
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        with self.session(use_gpu=True) as sess:
            if time_major:
                inputs = np.random.randn(max_time, batch_size,
                                         input_depth).astype(np.float32)
            else:
                inputs = np.random.randn(batch_size, max_time,
                                         input_depth).astype(np.float32)
            cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
            helper = seq2seq.TrainingHelper(inputs,
                                            sequence_length,
                                            time_major=time_major)
            my_decoder = seq2seq.BasicDecoder(cell=cell,
                                              helper=helper,
                                              initial_state=cell.zero_state(
                                                  dtype=tf.float32,
                                                  batch_size=batch_size))

            final_outputs, final_state, final_sequence_length = (
                seq2seq.dynamic_decode(my_decoder,
                                       output_time_major=time_major,
                                       maximum_iterations=maximum_iterations))

            def _t(shape):
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertIsInstance(final_outputs, seq2seq.BasicDecoderOutput)
            self.assertIsInstance(final_state, tf.nn.rnn_cell.LSTMStateTuple)

            self.assertEqual(
                (batch_size, ),
                tuple(final_sequence_length.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None, cell_depth)),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None)),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            sess.run(tf.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs": final_outputs,
                "final_state": final_state,
                "final_sequence_length": final_sequence_length,
            })

            # Mostly a smoke test
            time_steps = max_out
            expected_length = sequence_length
            if maximum_iterations is not None:
                time_steps = min(max_out, maximum_iterations)
                expected_length = [
                    min(x, maximum_iterations) for x in expected_length
                ]
            self.assertEqual(_t((batch_size, time_steps, cell_depth)),
                             sess_results["final_outputs"].rnn_output.shape)
            self.assertEqual(_t((batch_size, time_steps)),
                             sess_results["final_outputs"].sample_id.shape)
            self.assertCountEqual(expected_length,
                                  sess_results["final_sequence_length"])