def train(
    model_config,
    input_embeddings,
    source_len,
    output_vocab_size,
    output_vocab_embeddings_table,
    target_decode_steps,
    mode,
    input_copy_mask=None,
):
    """Constructs encoder and decoder transformation for training and eval.

  In the shapes described below, B is batch size, L is sequence length,
  D is the dimensionality of the model embeddings, and T is the output vocab
  size.

  Args:
    model_config: ModelConfig proto.
    input_embeddings: Tensor of shape (B, L, D) representing inputs.
    source_len: Tensor of shape (B) containing length of each input sequence.
    output_vocab_size: Size of output vocabulary.
    output_vocab_embeddings_table: Tensor of shape (T, D) representing table of
      embeddings for output symbols.
    target_decode_steps: DecodeSteps representing target outputs. Each tensor
      has shape (B, L).
    mode: Enum indicating model mode, TRAIN or EVAL.
    input_copy_mask: Mask for copying actions.

  Returns:
    Tuple of (logits, predicted_ids), where logits is a tensor of shape
    (B, L, T) representing model output logits, and predicted_ids is
    a tensor of shape (B, L) containing the derived integer IDs of the
    one-best output symbol.
  """
    logits = _transformer_body(
        input_embeddings,
        source_len,
        target_decode_steps,
        mode,
        model_config,
        output_vocab_size,
        output_vocab_embeddings_table,
        input_copy_mask=input_copy_mask,
    )

    predicted_ids = tf.to_int32(tf.argmax(logits, axis=-1))
    output_decode_steps = decode_utils.get_decode_steps(
        predicted_ids, output_vocab_size, model_config)
    predictions = decode_utils.get_predictions(output_decode_steps)
    return logits, predictions
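
To make the shape contract concrete, here is a minimal, self-contained sketch of the logits-to-ids reduction that train() performs. The toy shapes (B=2, L=3, T=5) are illustrative only, and the repo-specific pieces (_transformer_body, decode_utils, model_config) are omitted; tf.compat.v1 is used so the TF1-style graph code runs under TF2:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Toy shapes: B=2 sequences, L=3 steps, T=5 output symbols.
logits = tf.random.uniform([2, 3, 5])
# Same reduction as in train(): one-best symbol id per step.
predicted_ids = tf.to_int32(tf.argmax(logits, axis=-1))  # shape (B, L)

with tf.Session() as sess:
    print(sess.run(predicted_ids))  # e.g. [[4 0 2]
                                    #       [1 1 3]]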
def _greedy_decode(input_embeddings,
                   output_vocab_size,
                   target_end_id,
                   target_start_id,
                   output_vocab_embeddings_table,
                   source_len,
                   model_config,
                   mode,
                   input_copy_mask=None,
                   clean_output_mask=None):
    """Fast decoding."""
    encoder_output = common_layers.linear_transform(
        input_embeddings,
        output_size=model_config.model_parameters.encoder_dims,
        scope="bert_to_transformer")

    decode_length = model_config.data_options.max_decode_length

    # Maps the ids decoded so far to logits for the current step.
    def symbols_to_logits_fn(logit_indices, current_index):
        """Go from targets to logits."""
        logit_indices = tf.expand_dims(logit_indices, 0)
        decode_steps = decode_utils.get_decode_steps(logit_indices,
                                                     output_vocab_size,
                                                     model_config)
        target_embeddings = _get_target_embeddings(
            input_embeddings, output_vocab_embeddings_table, decode_steps,
            model_config)
        decoder_output = _build_transformer_decoder(
            encoder_output,
            source_len,
            target_embeddings,
            mode,
            model_config,
            single_step_index=current_index)

        logits = _get_action_logits(encoder_output,
                                    decoder_output,
                                    output_vocab_embeddings_table,
                                    output_vocab_size,
                                    model_config,
                                    input_copy_mask=input_copy_mask,
                                    clean_output_mask=clean_output_mask)

        # Squeeze batch dimension and length dimension, as both should be 1.
        logits = tf.squeeze(logits, axis=[0, 1])
        # Logits now have shape (extended vocab size,): the output
        # vocabulary plus any copy actions.
        return logits

    def loop_cond(i, decoded_ids, unused_logprobs):
        """Loop conditional that returns false to stop loop."""
        return tf.logical_and(
            tf.reduce_all(tf.not_equal(decoded_ids, target_end_id)),
            tf.less(i, decode_length))

    def inner_loop(i, decoded_ids, logprobs):
        """Decoder function invoked on each while loop iteration."""
        logits = symbols_to_logits_fn(decoded_ids, i)
        next_id = tf.argmax(logits, axis=0)
        softmax = tf.nn.softmax(logits)
        extended_vocab_size = tf.shape(softmax)[-1]
        mask = tf.one_hot(next_id, extended_vocab_size)
        prob = tf.reduce_sum(softmax * mask)
        logprob = tf.log(prob)

        # Write step i + 1 with an additive one-hot mask: every position past
        # i + 1 is still zero, so the add fills exactly one slot.
        logprobs += tf.one_hot(i + 1,
                               decode_length + 1,
                               on_value=logprob,
                               dtype=tf.float32)
        decoded_ids += tf.one_hot(i + 1,
                                  decode_length + 1,
                                  on_value=next_id,
                                  dtype=tf.int64)

        return i + 1, decoded_ids, logprobs

    initial_ids = tf.zeros(dtype=tf.int64, shape=[decode_length + 1])
    initial_ids += tf.one_hot(0,
                              decode_length + 1,
                              on_value=tf.cast(target_start_id, tf.int64))
    initial_logprob = tf.zeros(dtype=tf.float32, shape=[decode_length + 1])
    initial_i = tf.constant(0)

    initial_values = [initial_i, initial_ids, initial_logprob]

    _, decoded_ids, logprobs = tf.while_loop(loop_cond, inner_loop,
                                             initial_values)

    # Remove <START> symbol.
    decoded_ids = decoded_ids[1:]
    logprobs = logprobs[1:]
    # Sum logprobs to get scores for overall sequence.
    logprobs = tf.reduce_sum(logprobs, axis=0)

    # Expand decoded_ids and logprobs to reflect beam width dimension of 1.
    decoded_ids = tf.expand_dims(decoded_ids, 0)
    logprobs = tf.expand_dims(logprobs, 0)

    # This is the output dict that the function returns.
    output_decode_steps = decode_utils.get_decode_steps(
        decoded_ids, output_vocab_size, model_config)
    predictions = decode_utils.get_predictions(output_decode_steps)
    predictions[constants.SCORES_KEY] = logprobs

    return predictions
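
The inner loop above sidesteps TensorArray by pre-allocating fixed-length output tensors and writing one slot per step with an additive one-hot mask. A minimal, self-contained sketch of that write pattern with toy values (no model code involved; tf.compat.v1 for TF1-style graph execution):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

length = 4  # analogous to decode_length + 1

def cond(i, ids):
    return tf.less(i, length - 1)

def body(i, ids):
    # Stand-in for the argmax over logits: pretend step i chose symbol i + 10.
    next_id = tf.cast(i + 10, tf.int64)
    # Positions > i + 1 are still zero, so adding a one-hot vector
    # writes exactly one slot without any mutation ops.
    ids += tf.one_hot(i + 1, length, on_value=next_id, dtype=tf.int64)
    return i + 1, ids

initial_ids = tf.zeros([length], dtype=tf.int64)
_, ids = tf.while_loop(cond, body, [tf.constant(0), initial_ids])

with tf.Session() as sess:
    print(sess.run(ids))  # [ 0 10 11 12]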
def _beam_decode(input_embeddings,
                 alpha,
                 output_vocab_size,
                 target_end_id,
                 target_start_id,
                 output_vocab_embeddings_table,
                 source_len,
                 model_config,
                 mode,
                 beam_size,
                 input_copy_mask=None,
                 clean_output_mask=None):
    """Beam search decoding."""
    # Assume batch size is 1.
    batch_size = 1
    encoder_output = common_layers.linear_transform(
        input_embeddings,
        output_size=model_config.model_parameters.encoder_dims,
        scope="bert_to_transformer")

    decode_length = model_config.data_options.max_decode_length

    # Expand decoder inputs to the beam width.
    input_embeddings = tf.tile(input_embeddings, [beam_size, 1, 1])
    encoder_output = tf.tile(encoder_output, [beam_size, 1, 1])

    def symbols_to_logits_fn(current_index, logit_indices):
        """Go from targets to logits.

    Args:
      current_index: Integer corresponding to 0-indexed decoder step.
      logit_indices: Tensor of shape [batch_size * beam_width, decode_length +
        1] to input to decoder.

    Returns:
      Tensor of shape [batch_size * beam_width, output_vocab_size] representing
      logits for the current decoder step.

    Raises:
      ValueError if inputs do not have static length.
    """
        decode_steps = decode_utils.get_decode_steps(logit_indices,
                                                     output_vocab_size,
                                                     model_config)
        target_embeddings = _get_target_embeddings(
            input_embeddings, output_vocab_embeddings_table, decode_steps,
            model_config)
        decoder_output = _build_transformer_decoder(
            encoder_output,
            source_len,
            target_embeddings,
            mode,
            model_config,
            single_step_index=current_index)
        logits = _get_action_logits(encoder_output,
                                    decoder_output,
                                    output_vocab_embeddings_table,
                                    output_vocab_size,
                                    model_config,
                                    input_copy_mask=input_copy_mask,
                                    clean_output_mask=clean_output_mask)
        # Squeeze length dimension, as it should be 1.
        logits = tf.squeeze(logits, axis=[1])
        # Shape of logits should now be:
        # [batch_size * beam_width, extended vocab size].
        return logits

    initial_ids = tf.ones([batch_size], dtype=tf.int32) * target_start_id
    source_length = input_embeddings.get_shape()[1]
    if source_length.value is None:
        # Fall back on using dynamic shape information.
        source_length = tf.shape(input_embeddings)[1]
    extended_vocab_size = output_vocab_size + source_length
    # ids has shape: [batch_size, beam_size, decode_length]
    # scores has shape: [batch_size, beam_size]
    ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                          beam_size, decode_length,
                                          extended_vocab_size, alpha,
                                          target_end_id, batch_size)
    # Remove start symbol from returned predicted IDs.
    predicted_ids = ids[:, :, 1:]
    # Since batch size is expected to be 1, squeeze the batch dimension.
    predicted_ids = tf.squeeze(predicted_ids, axis=[0])
    scores = tf.squeeze(scores, axis=[0])
    # This is the output dict that the function returns.
    output_decode_steps = decode_utils.get_decode_steps(
        predicted_ids, output_vocab_size, model_config)
    predictions = decode_utils.get_predictions(output_decode_steps)
    predictions[constants.SCORES_KEY] = scores
    return predictions
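
A note on the extended_vocab_size = output_vocab_size + source_length arithmetic above: together with input_copy_mask, it suggests the action space appends one copy action per input position after the generation vocabulary. The exact id convention lives in decode_utils, so the layout below is an assumption for illustration only:

# Hypothetical sizes, for illustration only.
output_vocab_size = 100  # generation actions
source_length = 20       # one assumed copy action per input position
extended_vocab_size = output_vocab_size + source_length  # 120

def describe(action_id):
    # Assumed convention: ids below output_vocab_size generate a vocabulary
    # symbol; ids at or above it copy the corresponding input token.
    if action_id < output_vocab_size:
        return "generate vocab symbol %d" % action_id
    return "copy input position %d" % (action_id - output_vocab_size)

print(describe(42))   # generate vocab symbol 42
print(describe(105))  # copy input position 5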