Example #1
0
    def sample(self,
               n,
               max_length=None,
               z=None,
               temperature=1.0,
               start_inputs=None,
               end_fn=None):
        """Sample from decoder with an optional conditional latent vector `z`.

        Args:
          n: Scalar number of samples to return.
          max_length: (Optional) Scalar maximum sample length to return.
            Required if data representation does not include end tokens.
          z: (Optional) Latent vectors to sample from. Required if model is
            conditional. Sized `[n, z_size]`.
          temperature: (Optional) The softmax temperature to use when sampling,
            if applicable.
          start_inputs: (Optional) Initial inputs to use for batch.
            Sized `[n, output_depth]`. Defaults to zeros.
          end_fn: (Optional) A callable that takes a batch of samples (sized
            `[n, output_depth]`) and emits a `bool` vector shaped `[n]`
            indicating whether each sample is an end token. Defaults to a
            callable that never signals an end.
        Returns:
          samples: Sampled sequences. Sized `[n, max_length, output_depth]`.
          final_state: The final states of the decoder.
        Raises:
          ValueError: If `z` is provided and its static first dimension does
            not equal `n` (this includes a statically unknown dimension).
        """
        if z is not None:
            # Static batch size of `z`; `None` when unknown at graph-build time.
            # Formatted with %s so an unknown dimension still yields the
            # documented ValueError instead of a TypeError from '%d' % None.
            z_batch = z.shape[0].value
            if z_batch != n:
                raise ValueError(
                    '`z` must have a first dimension that equals `n` when given. '
                    'Got: %s vs %s' % (z_batch, n))

        # Use a dummy Z in unconditional case: zero columns wide, so the
        # concatenations below leave the inputs unchanged.
        z = tf.zeros((n, 0), tf.float32) if z is None else z

        # If not given, start with zeros.
        start_inputs = start_inputs if start_inputs is not None else tf.zeros(
            [n, self._output_depth], dtype=tf.float32)
        # In the conditional case, also concatenate the Z.
        start_inputs = tf.concat([start_inputs, z], axis=-1)

        sample_fn = lambda x: self._sample(x, temperature)
        # Default never reports an end token; stopping then presumably relies
        # on `max_length` (see the Args note above).
        end_fn = end_fn or (lambda x: False)
        # In the conditional case, concatenate Z to the sampled value.
        next_inputs_fn = lambda x: tf.concat([x, z], axis=-1)

        sampler = seq2seq.InferenceHelper(sample_fn,
                                          sample_shape=[self._output_depth],
                                          sample_dtype=tf.float32,
                                          start_inputs=start_inputs,
                                          end_fn=end_fn,
                                          next_inputs_fn=next_inputs_fn)

        decoder_outputs, final_state = self._decode(z,
                                                    helper=sampler,
                                                    max_length=max_length)

        return decoder_outputs.sample_id, final_state
Example #2
0
    def sample(self,
               n,
               max_length=None,
               z=None,
               temperature=1.0,
               start_inputs=None,
               end_fn=None):
        """Sample from the decoder with an optional conditional latent `z`.

        Args:
          n: Scalar number of samples to return.
          max_length: (Optional) Maximum sample length, forwarded to
            `self._decode`.
          z: (Optional) Latent vectors to condition on; its first dimension
            must equal `n`. Presumably sized `[n, z_size]`.
          temperature: (Optional) Forwarded to `self._sample` for every step.
          start_inputs: (Optional) Initial inputs for the batch. Defaults to
            zeros sized `[n, output_depth]`.
          end_fn: (Optional) Callable mapping a batch of samples to per-sample
            end flags. Defaults to a callable that never signals an end.
        Returns:
          The `sample_id` field of the decoder outputs, i.e. the sampled
          sequences (presumably sized `[n, max_length, output_depth]`).
        Raises:
          ValueError: If `z` is provided and its static first dimension does
            not equal `n` (this includes a statically unknown dimension).
        """
        if z is not None:
            # Static batch size of `z`; `None` when unknown at graph-build time.
            # Formatted with %s so an unknown dimension still yields the
            # documented ValueError instead of a TypeError from '%d' % None.
            z_batch = z.shape[0].value
            if z_batch != n:
                raise ValueError(
                    '`z` must have a first dimension that equals `n` when given. '
                    'Got: %s vs %s' % (z_batch, n))

        # Use a dummy Z in unconditional case: zero columns wide, so the
        # concatenations below leave the inputs unchanged.
        z = tf.zeros((n, 0), tf.float32) if z is None else z

        # If not given, start with zeros.
        start_inputs = start_inputs if start_inputs is not None else tf.zeros(
            [n, self._output_depth], dtype=tf.float32)
        # In the conditional case, also concatenate the Z.
        start_inputs = tf.concat([start_inputs, z], axis=-1)

        sample_fn = lambda x: self._sample(x, temperature)
        # Default never reports an end token.
        end_fn = end_fn or (lambda x: False)
        # In the conditional case, concatenate Z to the sampled value.
        next_inputs_fn = lambda x: tf.concat([x, z], axis=-1)

        sampler = seq2seq.InferenceHelper(sample_fn,
                                          sample_shape=[self._output_depth],
                                          sample_dtype=tf.float32,
                                          start_inputs=start_inputs,
                                          end_fn=end_fn,
                                          next_inputs_fn=next_inputs_fn)

        decoder_outputs = self._decode(n,
                                       helper=sampler,
                                       z=z,
                                       max_length=max_length)

        return decoder_outputs.sample_id
    # 7. Output layer
    out = tf.matmul(h_flat_drop, W_fc1_) + b_fc1_
    if clas:
        return tf.reshape(out, [b_size, seq_size, -1])
    else:
        return tf.reshape(out, [b_size, -1])


# In[18]:


# Inference helper with no embedding lookup: each raw decoder output is fed
# back unchanged as the next step's input. InferenceHelper also accepts an
# optional `next_inputs_fn` if that fed-back value ever needs transforming.
def _letters_finished(sample_ids):
    """Per-row end flags: True where CNN_2's argmax class for the sample is 1."""
    predicted = tf.argmax(CNN_2(sample_ids, False), axis=-1,
                          output_type=tf.int32)
    return tf.equal(predicted, 1)


helper_infer = seq2seq.InferenceHelper(
    sample_fn=lambda outputs: outputs,
    sample_shape=[letter_size],
    sample_dtype=tf.float32,
    # Every row starts filled with LETTER_PAD, which is used in place of an
    # EOS token; cast to float32 to match the helper's `sample_dtype`.
    start_inputs=tf.cast(tf.fill([batch_size, letter_size], LETTER_PAD),
                         tf.float32),
    end_fn=_letters_finished)

decoder_infer = seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=helper_infer,
    initial_state=decoder_initial_state,
    output_layer=projection_layer)

# Decode until every row reports finished or `output_length` steps elapse.
outputs_infer, final_context_state, final_seq_lengths = seq2seq.dynamic_decode(
    decoder_infer,
    impute_finished=True,
    maximum_iterations=output_length)