    def _symbols_to_logits_fn(self,
                              embedding,
                              vocab_size,
                              mode,
                              output_layer=None,
                              dtype=None):
        embedding_fn = get_embedding_fn(embedding)
        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=dtype)

        def _impl(ids, step, cache):
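            # Embed only the last predicted token, add the positional encoding
            # for the current step, and run the self-attention stack over the
            # cached states before projecting to the vocabulary.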
            inputs = embedding_fn(ids[:, -1:])
            inputs *= self.num_units**0.5
            inputs = self.position_encoder.apply_one(inputs, step + 1)
            outputs = self._self_attention_stack(inputs,
                                                 mode=mode,
                                                 cache=cache,
                                                 memory=cache["memory"],
                                                 memory_sequence_length=None)
            outputs = outputs[:, -1:, :]
            logits = output_layer(outputs)
            return logits, cache

        return _impl

    def decode(self,
               inputs,
               sequence_length,
               vocab_size=None,
               initial_state=None,
               sampling_probability=None,
               embedding=None,
               output_layer=None,
               mode=tf.estimator.ModeKeys.TRAIN,
               memory=None,
               memory_sequence_length=None):
        if sampling_probability is not None:
            raise ValueError(
                "Scheduled sampling is not supported with SelfAttentionDecoder"
            )

        inputs *= self.num_units**0.5
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs,
                                           sequence_length=sequence_length)

        outputs = self._self_attention_stack(
            inputs,
            sequence_length=sequence_length,
            mode=mode,
            memory=memory,
            memory_sequence_length=memory_sequence_length)

        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=inputs.dtype)
        logits = output_layer(outputs)

        return (logits, None, sequence_length)
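A minimal usage sketch for this training-time decode, assuming an already constructed decoder instance (`decoder`), embedded and shifted target inputs, and gold target ids; these names, shapes, and the vocabulary size are illustrative assumptions rather than part of the snippet above.

import tensorflow as tf

# Assumed tensors (hypothetical names):
#   target_in:  [batch, time, num_units] embedded, shifted target sequence
#   target_out: [batch, time] gold target ids used as labels
#   target_len: [batch] target lengths
logits, _, lengths = decoder.decode(
    target_in,
    target_len,
    vocab_size=32000,
    mode=tf.estimator.ModeKeys.TRAIN,
    memory=encoder_outputs,
    memory_sequence_length=source_len)

# Mask padded positions and compute the token-level cross-entropy.
weights = tf.sequence_mask(
    lengths, maxlen=tf.shape(logits)[1], dtype=logits.dtype)
loss = tf.contrib.seq2seq.sequence_loss(logits, target_out, weights)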
Example #3
  def decode(self,
             inputs,
             sequence_length,
             vocab_size=None,
             initial_state=None,
             sampling_probability=None,
             embedding=None,
             output_layer=None,
             mode=tf.estimator.ModeKeys.TRAIN,
             memory=None,
             memory_sequence_length=None):

    batch_size = tf.shape(inputs)[0]

    if (sampling_probability is not None
        and (tf.contrib.framework.is_tensor(sampling_probability)
             or sampling_probability > 0.0)):
      if embedding is None:
        raise ValueError("embedding argument must be set when using scheduled sampling")

      tf.summary.scalar("sampling_probability", sampling_probability)
      helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
          inputs,
          sequence_length,
          embedding,
          sampling_probability)
    else:
      helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length)

    cell, initial_state = self._build_cell(
        mode,
        batch_size,
        initial_state=initial_state,
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        dtype=inputs.dtype)

    if output_layer is None:
      output_layer = build_output_layer(self.num_units, vocab_size, dtype=inputs.dtype)

    # With TrainingHelper, project all timesteps at once.
    fused_projection = isinstance(helper, tf.contrib.seq2seq.TrainingHelper)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell,
        helper,
        initial_state,
        output_layer=output_layer if not fused_projection else None)

    outputs, state, length = tf.contrib.seq2seq.dynamic_decode(decoder)

    if fused_projection and output_layer is not None:
      logits = output_layer(outputs.rnn_output)
    else:
      logits = outputs.rnn_output

    return (logits, state, length)
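The scheduled sampling branch above accepts sampling_probability either as a Python float or as a tensor (the tf.contrib.framework.is_tensor check). A hedged sketch of wiring it up with a simple linear warm-up schedule; the schedule, constants, and tensor names are illustrative assumptions.

import tensorflow as tf

# Ramp the sampling probability linearly from 0 to 0.25 over the first
# 100k training steps (illustrative schedule, not from the snippet above).
global_step = tf.train.get_or_create_global_step()
sampling_probability = tf.minimum(
    0.25, tf.cast(global_step, tf.float32) / 100000.0 * 0.25)

logits, state, length = decoder.decode(
    target_in,                      # embedded decoder inputs (assumed)
    target_len,
    vocab_size=32000,
    sampling_probability=sampling_probability,
    embedding=target_embedding,     # required when sampling is enabled
    mode=tf.estimator.ModeKeys.TRAIN)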
Example #4
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size=None,
                                  initial_state=None,
                                  output_layer=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  memory=None,
                                  memory_sequence_length=None,
                                  dtype=None):
        batch_size = tf.shape(start_tokens)[0]

        # Replicate batch `beam_width` times.
        if initial_state is not None:
            initial_state = tf.contrib.seq2seq.tile_batch(
                initial_state, multiplier=beam_width)
        if memory is not None:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
        if memory_sequence_length is not None:
            memory_sequence_length = tf.contrib.seq2seq.tile_batch(
                memory_sequence_length, multiplier=beam_width)

        cell, initial_state = self._build_cell(
            mode,
            batch_size * beam_width,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=dtype)

        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=dtype or memory.dtype)

        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell,
            embedding,
            start_tokens,
            end_token,
            initial_state,
            beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty)

        outputs, beam_state, length = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations)

        predicted_ids = tf.transpose(outputs.predicted_ids, perm=[0, 2, 1])
        log_probs = beam_state.log_probs
        state = beam_state.cell_state

        return (predicted_ids, state, length, log_probs)
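After the transpose above, predicted_ids has shape [batch, beam_width, time] and log_probs holds the cumulative score of each beam. A short sketch of selecting the best hypothesis, assuming a decoder instance, embedding variable, and start/end ids that are outside this snippet.

import tensorflow as tf

predicted_ids, state, length, log_probs = decoder.dynamic_decode_and_search(
    target_embedding,               # embedding variable or callable (assumed)
    start_tokens=tf.fill([batch_size], start_id),
    end_token=end_id,
    vocab_size=32000,
    beam_width=5,
    length_penalty=0.6,
    memory=encoder_outputs,
    memory_sequence_length=source_len)

# Beams come out sorted by score, so the first beam is the best hypothesis.
best_ids = predicted_ids[:, 0, :]
best_log_prob = log_probs[:, 0]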
Example #5
  def dynamic_decode(self,
                     embedding,
                     start_tokens,
                     end_token,
                     vocab_size=None,
                     initial_state=None,
                     output_layer=None,
                     maximum_iterations=250,
                     mode=tf.estimator.ModeKeys.PREDICT,
                     memory=None,
                     memory_sequence_length=None,
                     dtype=None,
                     return_alignment_history=False):
    batch_size = tf.shape(start_tokens)[0]

    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding,
        start_tokens,
        end_token)

    cell, initial_state = self._build_cell(
        mode,
        batch_size,
        initial_state=initial_state,
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        dtype=dtype,
        alignment_history=return_alignment_history)

    if output_layer is None:
      output_layer = build_output_layer(self.num_units, vocab_size, dtype=dtype or memory.dtype)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell,
        helper,
        initial_state,
        output_layer=output_layer)

    outputs, state, length = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=maximum_iterations)

    predicted_ids = outputs.sample_id
    log_probs = logits_to_cum_log_probs(outputs.rnn_output, length)

    # Make shape consistent with beam search.
    predicted_ids = tf.expand_dims(predicted_ids, 1)
    length = tf.expand_dims(length, 1)
    log_probs = tf.expand_dims(log_probs, 1)

    if return_alignment_history:
      alignment_history = _get_alignment_history(state)
      if alignment_history is not None:
        alignment_history = tf.expand_dims(alignment_history, 1)
      return (predicted_ids, state, length, log_probs, alignment_history)
    return (predicted_ids, state, length, log_probs)
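Greedy decoding returns the same layout as beam search with a beam dimension of 1 (hence the expand_dims calls), which keeps downstream code uniform. A usage sketch under the same assumptions as the previous examples.

import tensorflow as tf

predicted_ids, state, length, log_probs = decoder.dynamic_decode(
    target_embedding,
    start_tokens=tf.fill([batch_size], start_id),
    end_token=end_id,
    vocab_size=32000,
    memory=encoder_outputs,
    memory_sequence_length=source_len)

# Shapes: predicted_ids [batch, 1, time], length [batch, 1], log_probs [batch, 1].
hypotheses = predicted_ids[:, 0, :]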
Example #6
    def decode(self,
               inputs,
               sequence_length,
               vocab_size=None,
               initial_state=None,
               sampling_probability=None,
               embedding=None,
               output_layer=None,
               mode=tf.estimator.ModeKeys.TRAIN,
               memory=None,
               memory_sequence_length=None,
               return_alignment_history=False):

        batch_size = tf.shape(inputs)[0]

        if (sampling_probability is not None
                and (tf.contrib.framework.is_tensor(sampling_probability)
                     or sampling_probability > 0.0)):
            if embedding is None:
                raise ValueError(
                    "embedding argument must be set when using scheduled sampling"
                )

            tf.summary.scalar("sampling_probability", sampling_probability)
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs, sequence_length, embedding, sampling_probability)
            fused_projection = False
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length)
            fused_projection = True  # With TrainingHelper, project all timesteps at once.

        cell, initial_state = self._build_cell(
            mode,
            batch_size,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=inputs.dtype,
            alignment_history=return_alignment_history)

        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=inputs.dtype)

        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell,
            helper,
            initial_state,
            output_layer=output_layer if not fused_projection else None)

        outputs, state, length = tf.contrib.seq2seq.dynamic_decode(decoder)

        if fused_projection and output_layer is not None:
            logits = output_layer(outputs.rnn_output)
        else:
            logits = outputs.rnn_output
        # Make sure outputs have the same time_dim as inputs
        inputs_len = tf.shape(inputs)[1]
        logits = align_in_time(logits, inputs_len)

        if return_alignment_history:
            alignment_history = _get_alignment_history(state)
            if alignment_history is not None:
                alignment_history = align_in_time(alignment_history,
                                                  inputs_len)
            return (logits, state, length, alignment_history)
        return (logits, state, length)
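align_in_time is used above to give the decoder outputs the same time dimension as the inputs, since dynamic decoding can stop early. A possible sketch of such a helper for rank-3 tensors, written from how it is called here rather than copied from the library.

import tensorflow as tf

def align_in_time(x, length):
    """Pads or truncates axis 1 of a [batch, time, depth] tensor to `length`
    (illustrative reimplementation, assumptions only)."""
    time_dim = tf.shape(x)[1]
    return tf.cond(
        tf.less(time_dim, length),
        true_fn=lambda: tf.pad(x, [[0, 0], [0, length - time_dim], [0, 0]]),
        false_fn=lambda: x[:, :length])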
Example #7
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size=None,
                                  initial_state=None,
                                  output_layer=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  memory=None,
                                  memory_sequence_length=None,
                                  dtype=None,
                                  return_alignment_history=False):
        if (return_alignment_history and "reorder_tensor_arrays"
                not in fn_args(tf.contrib.seq2seq.BeamSearchDecoder.__init__)):
            tf.logging.warn(
                "The current version of tf.contrib.seq2seq.BeamSearchDecoder "
                "does not support returning the alignment history. None will "
                "be returned instead. Consider upgrading TensorFlow.")
            alignment_history = False
        else:
            alignment_history = return_alignment_history

        batch_size = tf.shape(start_tokens)[0]

        # Replicate batch `beam_width` times.
        if initial_state is not None:
            initial_state = tf.contrib.seq2seq.tile_batch(
                initial_state, multiplier=beam_width)
        if memory is not None:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
        if memory_sequence_length is not None:
            memory_sequence_length = tf.contrib.seq2seq.tile_batch(
                memory_sequence_length, multiplier=beam_width)

        cell, initial_state = self._build_cell(
            mode,
            batch_size * beam_width,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=dtype,
            alignment_history=alignment_history)

        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=dtype or memory.dtype)

        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell,
            embedding,
            start_tokens,
            end_token,
            initial_state,
            beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty)

        outputs, beam_state, length = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations)

        predicted_ids = tf.transpose(outputs.predicted_ids, perm=[0, 2, 1])
        log_probs = beam_state.log_probs
        state = beam_state.cell_state

        if return_alignment_history:
            alignment_history = _get_alignment_history(state)
            if alignment_history is not None:
                alignment_history = tf.reshape(
                    alignment_history,
                    [batch_size, beam_width, -1,
                     tf.shape(memory)[1]])
            return (predicted_ids, state, length, log_probs, alignment_history)
        return (predicted_ids, state, length, log_probs)
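The reorder_tensor_arrays check above guards against TensorFlow versions whose BeamSearchDecoder cannot reorder the attention TensorArrays across beams. A plausible sketch of the fn_args helper it relies on, assuming it simply lists a callable's argument names; the actual helper may differ.

import inspect

def fn_args(fn):
    """Returns the argument names of a callable (illustrative sketch)."""
    return tuple(inspect.signature(fn).parameters.keys())
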
    def decode(self,
               inputs,
               sequence_length,
               enc_batch_extend_vocab,
               max_art_oovs,
               vocab_size=None,
               initial_state=None,
               sampling_probability=None,
               embedding=None,
               output_layer=None,
               mode=tf.estimator.ModeKeys.TRAIN,
               memory=None,
               memory_sequence_length=None,
               return_alignment_history=False):

        batch_size = tf.shape(inputs)[0]

        if (sampling_probability is not None
                and (tf.contrib.framework.is_tensor(sampling_probability)
                     or sampling_probability > 0.0)):
            if embedding is None:
                raise ValueError(
                    "embedding argument must be set when using scheduled sampling"
                )

            tf.summary.scalar("sampling_probability", sampling_probability)
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs, sequence_length, embedding, sampling_probability)
            fused_projection = False
        else:
            # Greedy decoding with the pointer-generator helper; the start
            # tokens (id 2) and end token (id 3) are hard-coded here.
            helper = PointerGeneratorGreedyEmbeddingHelper(
                self.embeddings,
                tf.ones([batch_size], tf.int32) * 2,  # start tokens
                3)                                    # end token
            fused_projection = True  # Note: the fused projection below is commented out.

        cell, initial_state = self._build_cell(
            mode,
            batch_size,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=inputs.dtype)

        if output_layer is None:
            output_layer = decoder.build_output_layer(self.output_size,
                                                      vocab_size,
                                                      dtype=inputs.dtype)

        basic_decoder = PointerGeneratorDecoder(enc_batch_extend_vocab,
                                                max_art_oovs,
                                                self.coverage,
                                                cell,
                                                helper,
                                                initial_state,
                                                output_layer=output_layer)

        outputs, state, length = tf.contrib.seq2seq.dynamic_decode(
            basic_decoder, maximum_iterations=FLAGS.max_dec_steps)

        # if fused_projection and output_layer is not None:
        #     logits = output_layer(outputs.rnn_output)
        # else:
        logits = outputs.rnn_output
        # Make sure outputs have the same time_dim as inputs
        inputs_len = tf.shape(inputs)[1]
        logits = align_in_time(logits, inputs_len)

        if return_alignment_history:
            alignment_history = self._get_attention(state)
            if alignment_history is not None:
                alignment_history = align_in_time(alignment_history,
                                                  inputs_len)
            return (logits, state, length, alignment_history)
        return (logits, state, length)
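The pointer-generator decoder scores an extended vocabulary of size vocab_size + max_art_oovs, so predicted ids at or beyond vocab_size refer to per-article source OOVs. A small sketch of splitting predictions accordingly, assuming the usual pointer-generator id layout; the tensor names are illustrative.

import tensorflow as tf

# [batch, time] ids over the extended vocabulary.
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
# Positions that copied a source OOV (extended ids start at vocab_size).
copies = tf.greater_equal(predictions, vocab_size)
# For copied positions, this indexes into the per-article OOV list.
oov_index = tf.maximum(predictions - vocab_size, 0)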