コード例 #1
0
ファイル: lstm_utils_test.py プロジェクト: ThierryGrb/magenta
  def testStateTupleToCudnnLstmState(self):
    """Checks per-layer LSTMStateTuples are stacked into a single (h, c) pair.

    Each case pairs the input tuple of per-layer states with the `h` and `c`
    arrays expected back from `state_tuples_to_cudnn_lstm_state`: the layers
    end up stacked along a new leading axis.
    """
    single_layer_case = (
        (rnn.LSTMStateTuple(h=np.arange(10).reshape(5, 2),
                            c=np.arange(10, 20).reshape(5, 2)),),
        np.arange(10).reshape(1, 5, 2),
        np.arange(10, 20).reshape(1, 5, 2))
    two_layer_case = (
        (rnn.LSTMStateTuple(h=np.arange(10).reshape(5, 2),
                            c=np.arange(20, 30).reshape(5, 2)),
         rnn.LSTMStateTuple(h=np.arange(10, 20).reshape(5, 2),
                            c=np.arange(30, 40).reshape(5, 2))),
        np.arange(20).reshape(2, 5, 2),
        np.arange(20, 40).reshape(2, 5, 2))

    with self.test_session():
      # Cases run in order so graph ops are built and evaluated one at a time.
      for state_tuples, want_h, want_c in (single_layer_case, two_layer_case):
        h, c = lstm_utils.state_tuples_to_cudnn_lstm_state(state_tuples)
        self.assertAllEqual(want_h, h.eval())
        self.assertAllEqual(want_c, c.eval())
コード例 #2
0
    def testStateTupleToCudnnLstmState(self):
        """Checks LSTMStateTuples are stacked into the (h, c) pair returned
        by `state_tuples_to_cudnn_lstm_state`, layers along the first axis."""

        def check_conversion(state_tuples, expected_h, expected_c):
            # Convert, then compare each stacked tensor with its expectation.
            cudnn_h, cudnn_c = lstm_utils.state_tuples_to_cudnn_lstm_state(
                state_tuples)
            self.assertAllEqual(expected_h, cudnn_h.eval())
            self.assertAllEqual(expected_c, cudnn_c.eval())

        with self.test_session():
            # One layer: the (5, 2) states gain a leading axis of size 1.
            check_conversion(
                (rnn.LSTMStateTuple(h=np.arange(10).reshape(5, 2),
                                    c=np.arange(10, 20).reshape(5, 2)), ),
                np.arange(10).reshape(1, 5, 2),
                np.arange(10, 20).reshape(1, 5, 2))

            # Two layers: per-layer states are stacked in layer order.
            check_conversion(
                (rnn.LSTMStateTuple(h=np.arange(10).reshape(5, 2),
                                    c=np.arange(20, 30).reshape(5, 2)),
                 rnn.LSTMStateTuple(h=np.arange(10, 20).reshape(5, 2),
                                    c=np.arange(30, 40).reshape(5, 2))),
                np.arange(20).reshape(2, 5, 2),
                np.arange(20, 40).reshape(2, 5, 2))
コード例 #3
0
ファイル: lstm_models.py プロジェクト: yyzreal/magenta
  def _decode(self, z, helper, input_shape, max_length=None):
    """Decodes the given batch of latent vectors, which may be 0-length.

    If a CudnnLSTM decoder has been defined and `helper` is exactly a
    `seq2seq.TrainingHelper` (not a subclass), the whole input sequence is run
    through the CudnnLSTM in a single call. Otherwise decoding falls back to
    `seq2seq.dynamic_decode` over `self._dec_cell`, logging a warning when a
    CudnnLSTM exists but cannot be used.

    Args:
      z: Batch of latent vectors, sized `[batch_size, z_size]`, where `z_size`
        may be 0 for unconditioned decoding.
      helper: A seq2seq.Helper to use. If a TrainingHelper is passed and a
        CudnnLSTM has previously been defined, it will be used instead.
      input_shape: The shape of each model input vector passed to the decoder.
        Only used on the `dynamic_decode` path.
      max_length: (Optional) The maximum iterations to decode. Only applied on
        the `dynamic_decode` path; the CudnnLSTM path consumes all of
        `helper.inputs`.

    Returns:
      results: The LstmDecodeResults. On the CudnnLSTM path `samples` is an
        empty `[batch, 0]` tensor and `final_state` is None.
    """
    # Both paths start from a cell state derived from `z`.
    initial_state = lstm_utils.initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')

    # CudnnLSTM cannot sample, so it may only stand in for a plain
    # TrainingHelper; the exact type check rules out sampling subclasses.
    if self._cudnn_dec_lstm and type(helper) is seq2seq.TrainingHelper:  # pylint:disable=unidiomatic-typecheck
      # Swap the first two axes before the CudnnLSTM call and back afterwards
      # (presumably batch-major -> time-major, as CudnnLSTM expects).
      swapped_inputs = tf.transpose(helper.inputs, [1, 0, 2])
      cudnn_state = lstm_utils.state_tuples_to_cudnn_lstm_state(initial_state)
      swapped_outputs, _ = self._cudnn_dec_lstm(
          swapped_inputs,
          initial_state=cudnn_state,
          training=self._is_training)
      with tf.variable_scope('decoder'):
        swapped_outputs = self._output_layer(swapped_outputs)
      return lstm_utils.LstmDecodeResults(
          rnn_input=helper.inputs[:, :, :self._output_depth],
          rnn_output=tf.transpose(swapped_outputs, [1, 0, 2]),
          samples=tf.zeros([z.shape[0], 0]),
          # TODO(adarob): Pass the final state when it is valid (fixed-length).
          final_state=None,
          final_sequence_lengths=helper.sequence_length)

    # Fallback: step-by-step decoding through the regular RNN cell.
    if self._cudnn_dec_lstm:
      tf.logging.warning(
          'CudnnLSTM does not support sampling. Using `dynamic_decode` '
          'instead.')
    step_decoder = lstm_utils.Seq2SeqLstmDecoder(
        self._dec_cell,
        helper,
        initial_state=initial_state,
        input_shape=input_shape,
        output_layer=self._output_layer)
    decoded, last_state, decoded_lengths = seq2seq.dynamic_decode(
        step_decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')
    return lstm_utils.LstmDecodeResults(
        rnn_input=decoded.rnn_input[:, :, :self._output_depth],
        rnn_output=decoded.rnn_output,
        samples=decoded.sample_id,
        final_state=last_state,
        final_sequence_lengths=decoded_lengths)