Esempio n. 1
0
  def _decode(self, z, helper, input_shape, max_length=None):
    """Runs the decoder over a batch of latent vectors, which may be 0-length.

    The fused CudnnLSTM kernel is used only when one has been defined and
    `helper` is exactly a `seq2seq.TrainingHelper` (not a subclass). Every
    other case goes through `seq2seq.dynamic_decode`.

    Args:
      z: Latent vectors of shape `[batch_size, z_size]`. `z_size` may be 0 for
        unconditioned decoding.
      helper: The seq2seq.Helper that drives decoding. A TrainingHelper is
        replaced by the CudnnLSTM when one has previously been defined.
      input_shape: Shape of each model input vector passed to the decoder.
      max_length: (Optional) Maximum number of decode iterations.

    Returns:
      results: An `lstm_utils.LstmDecodeResults`.
    """
    cell_state = lstm_utils.initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')

    # CudnnLSTM cannot sample, so it may only stand in for a TrainingHelper.
    use_cudnn = (
        self._cudnn_dec_lstm and
        type(helper) is seq2seq.TrainingHelper)  # pylint:disable=unidiomatic-typecheck

    if use_cudnn:
      # The Cudnn layer is time-major; helper inputs are transposed to match.
      time_major_inputs = tf.transpose(helper.inputs, [1, 0, 2])
      cudnn_state = lstm_utils.state_tuples_to_cudnn_lstm_state(cell_state)
      time_major_outputs, _ = self._cudnn_dec_lstm(
          time_major_inputs,
          initial_state=cudnn_state,
          training=self._is_training)
      # Apply the projection under the same 'decoder' scope name that
      # `dynamic_decode` is given on the other path.
      with tf.variable_scope('decoder'):
        time_major_outputs = self._output_layer(time_major_outputs)
      return lstm_utils.LstmDecodeResults(
          rnn_input=helper.inputs[:, :, :self._output_depth],
          rnn_output=tf.transpose(time_major_outputs, [1, 0, 2]),
          samples=tf.zeros([z.shape[0], 0]),
          # TODO(adarob): Pass the final state when it is valid (fixed-length).
          final_state=None,
          final_sequence_lengths=helper.sequence_length)

    if self._cudnn_dec_lstm:
      tf.logging.warning(
          'CudnnLSTM does not support sampling. Using `dynamic_decode` '
          'instead.')
    seq_decoder = lstm_utils.Seq2SeqLstmDecoder(
        self._dec_cell, helper,
        initial_state=cell_state,
        input_shape=input_shape,
        output_layer=self._output_layer)
    outputs, last_state, lengths = seq2seq.dynamic_decode(
        seq_decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')
    return lstm_utils.LstmDecodeResults(
        rnn_input=outputs.rnn_input[:, :, :self._output_depth],
        rnn_output=outputs.rnn_output,
        samples=outputs.sample_id,
        final_state=last_state,
        final_sequence_lengths=lengths)
Esempio n. 2
0
 def _merge_decode_results(self, decode_results):
   """Concatenates a sequence of decode results along the time axis (axis 1).

   Args:
     decode_results: Non-empty sequence of `lstm_utils.LstmDecodeResults`,
       in time order.

   Returns:
     A single `lstm_utils.LstmDecodeResults`. It holds the time-concatenated
     outputs, inputs and samples, the last result's `final_state`, and the
     per-result sequence lengths stacked along axis 1.
   """
   assert decode_results
   time_axis = 1
   merged = lstm_utils.LstmDecodeResults(*zip(*decode_results))

   def _concat_or_none(parts):
     # A None first entry marks the field as absent; propagate None.
     if parts[0] is None:
       return None
     return tf.concat(parts, axis=time_axis)

   return lstm_utils.LstmDecodeResults(
       rnn_output=_concat_or_none(merged.rnn_output),
       rnn_input=_concat_or_none(merged.rnn_input),
       samples=tf.concat(merged.samples, axis=time_axis),
       final_state=merged.final_state[-1],
       final_sequence_lengths=tf.stack(
           merged.final_sequence_lengths, axis=time_axis))
Esempio n. 3
0
  def _decode(self, z, helper, input_shape, max_length=None):
    """Runs `dynamic_decode` from a batch of latent vectors (possibly 0-length).

    Args:
      z: Latent vectors of shape `[batch_size, z_size]`. `z_size` may be 0 for
        unconditioned decoding.
      helper: The seq2seq.Helper that drives decoding.
      input_shape: Shape of each model input vector passed to the decoder.
      max_length: (Optional) Maximum number of decode iterations.

    Returns:
      results: An `lstm_utils.LstmDecodeResults` built from the decoder output.
    """
    seq_decoder = lstm_utils.Seq2SeqLstmDecoder(
        self._dec_cell,
        helper,
        initial_state=lstm_utils.initial_cell_state_from_embedding(
            self._dec_cell, z, name='decoder/z_to_initial_state'),
        input_shape=input_shape,
        output_layer=self._output_layer)
    outputs, last_state, lengths = contrib_seq2seq.dynamic_decode(
        seq_decoder, maximum_iterations=max_length, swap_memory=True,
        scope='decoder')
    # Keep only the leading `_output_depth` features of each recorded input.
    trimmed_inputs = outputs.rnn_input[:, :, :self._output_depth]
    return lstm_utils.LstmDecodeResults(
        rnn_input=trimmed_inputs,
        rnn_output=outputs.rnn_output,
        samples=outputs.sample_id,
        final_state=last_state,
        final_sequence_lengths=lengths)
Esempio n. 4
0
 def _merge_decode_results(self, decode_results):
   """Merges per-decoder results by concatenating along the output dimension.

   Args:
     decode_results: Non-empty sequence of `lstm_utils.LstmDecodeResults`, one
       per sub-decoder. All of them are expected to have decoded exactly
       `self.hparams.max_seq_len` steps, and this is asserted at graph run
       time.

   Returns:
     A single `lstm_utils.LstmDecodeResults`:
     - `rnn_output`, `rnn_input` and `samples` are concatenated along the last
       axis.
     - `final_state`, if present, is concatenated leaf-by-leaf along the last
       axis. This assumes every sub-decoder state has the same nest structure.
     - `final_sequence_lengths` is taken from the first sub-decoder.
   """
   output_axis = -1
   assert decode_results
   zipped_results = lstm_utils.LstmDecodeResults(*zip(*decode_results))
   with tf.control_dependencies([
       tf.assert_equal(
           zipped_results.final_sequence_lengths, self.hparams.max_seq_len,
           message='Variable length not supported by '
                   'MultiOutCategoricalLstmDecoder.')]):
     rnn_output = tf.concat(zipped_results.rnn_output, axis=output_axis)
     rnn_input = tf.concat(zipped_results.rnn_input, axis=output_axis)
     samples = tf.concat(zipped_results.samples, axis=output_axis)
     if zipped_results.final_state[0] is None:
       final_state = None
     else:
       # Unpack the per-decoder states so map_structure zips them and hands
       # the matching leaves of every decoder to one concat. Mapping a
       # single-argument function over the tuple of states would concat each
       # leaf with nothing and leave the states unmerged.
       final_state = nest.map_structure(
           lambda *leaves: tf.concat(leaves, axis=output_axis),
           *zipped_results.final_state)
     return lstm_utils.LstmDecodeResults(
         rnn_output=rnn_output,
         rnn_input=rnn_input,
         samples=samples,
         final_state=final_state,
         # `tf.identity` creates an op inside the control-dependency scope, so
         # the length assertion also gates consumers of the lengths.
         final_sequence_lengths=tf.identity(
             zipped_results.final_sequence_lengths[0]))
Esempio n. 5
0
  def sample(self, n, max_length=None, z=None, c_input=None, temperature=None,
             start_inputs=None, beam_width=None, end_token=None):
    """Overrides BaseLstmDecoder `sample` method to add optional beam search.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      c_input: (Optional) Control sequence, sized `[max_length, control_depth]`.
        Only forwarded to the parent `sample` when `beam_width` is None. The
        beam search path does not use it.
      temperature: (Optional) The softmax temperature to use when not doing beam
        search. Defaults to 1.0. Ignored when `beam_width` is provided.
      start_inputs: (Optional) Initial inputs to use for batch.
        Sized `[n, output_depth]`.
      beam_width: (Optional) Width of beam to use for beam search. Beam search
        is disabled if not provided.
      end_token: (Optional) Scalar token signaling the end of the sequence to
        use for early stopping.
    Returns:
      When `beam_width` is None, whatever the parent `sample` returns.
      With beam search, a `(samples, results)` pair:
        samples: One-hot sequences taken from beam 0, sized
          `[n, decoded_length, output_depth]`. `decoded_length` is bounded by
          `max_length` when given.
        results: An `lstm_utils.LstmDecodeResults` whose `rnn_output` is None
          and whose `final_state` and `final_sequence_lengths` are beam 0's.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
    if beam_width is None:
      # Early stopping: a step ends the sequence when its argmax is end_token.
      end_fn = (None if end_token is None else
                lambda x: tf.equal(tf.argmax(x, axis=-1), end_token))
      return super(CategoricalLstmDecoder, self).sample(
          n, max_length, z, c_input, temperature, start_inputs, end_fn)

    # If `end_token` is not given, use an impossible value.
    end_token = self._output_depth if end_token is None else end_token
    if z is not None and z.shape[0].value != n:
      # `%s` (not `%d`) so an unknown (None) batch dimension still yields the
      # documented ValueError rather than a TypeError from formatting.
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %s vs %s' % (z.shape[0].value, n))

    if temperature is not None:
      tf.logging.warning('`temperature` is ignored when using beam search.')
    # Use a dummy Z in unconditional case.
    z = tf.zeros((n, 0), tf.float32) if z is None else z

    # If not given, start with dummy `-1` token and replace with zero vectors in
    # `embedding_fn`.
    start_tokens = (
        tf.argmax(start_inputs, axis=-1, output_type=tf.int32)
        if start_inputs is not None else
        -1 * tf.ones([n], dtype=tf.int32))

    initial_state = lstm_utils.initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')
    beam_initial_state = seq2seq.tile_batch(
        initial_state, multiplier=beam_width)

    # Tile `z` across beams.
    beam_z = tf.tile(tf.expand_dims(z, 1), [1, beam_width, 1])

    def embedding_fn(tokens):
      # If tokens are the start_tokens (negative), replace with zero vectors.
      # Only tokens[0, 0] is inspected, which assumes every entry of the batch
      # is either all-dummy or all-real at a given step.
      next_inputs = tf.cond(
          tf.less(tokens[0, 0], 0),
          lambda: tf.zeros([n, beam_width, self._output_depth]),
          lambda: tf.one_hot(tokens, self._output_depth))

      # Concatenate `z` to next inputs.
      next_inputs = tf.concat([next_inputs, beam_z], axis=-1)
      return next_inputs

    decoder = seq2seq.BeamSearchDecoder(
        self._dec_cell,
        embedding_fn,
        start_tokens,
        end_token,
        beam_initial_state,
        beam_width,
        output_layer=self._output_layer,
        length_penalty_weight=0.0)

    final_output, final_state, final_lengths = seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')

    samples = tf.one_hot(final_output.predicted_ids[:, :, 0],
                         self._output_depth)
    # Rebuild the input by combining the initial input with the sampled output.
    initial_inputs = (
        tf.zeros([n, 1, self._output_depth]) if start_inputs is None else
        tf.expand_dims(start_inputs, axis=1))
    rnn_input = tf.concat([initial_inputs, samples[:, :-1]], axis=1)

    results = lstm_utils.LstmDecodeResults(
        rnn_input=rnn_input,
        rnn_output=None,
        samples=samples,
        final_state=nest.map_structure(
            lambda x: x[:, 0], final_state.cell_state),
        final_sequence_lengths=final_lengths[:, 0])
    return samples, results