Example #1
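All snippets below come from the tensor2tensor codebase and assume its standard imports; a minimal preamble (module paths as in the tensor2tensor repository):

import tensorflow as tf

from tensor2tensor.data_generators import text_encoder
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers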
def generic_loss(top_out,
                 targets,
                 model_hparams,
                 weights_fn=common_layers.weights_nonzero,
                 gaussian=False):
    """Smoothed cross-entropy loss over padded logits and targets.

    Returns a (weighted loss sum, weight sum) pair; their ratio is the
    mean per-position loss.
    """
    logits = top_out
    labels = targets

    # Inline version of common_layers.padded_cross_entropy.
    confidence = 1.0 - model_hparams.label_smoothing
    logits_shape = common_layers.shape_list(logits)
    vocab_size = logits_shape[-1]

    logits, labels = common_layers.pad_with_zeros(logits, labels)
    logits = tf.reshape(logits,
                        common_layers.shape_list(labels) + [vocab_size],
                        name="padded_cross_entropy_size_check")
    logits = tf.cast(logits, tf.float32)
    xent = common_layers.smoothing_cross_entropy(logits,
                                                 labels,
                                                 vocab_size,
                                                 confidence,
                                                 gaussian=gaussian)
    weights = weights_fn(labels)
    return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
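A hypothetical call site (variable names are illustrative), showing how the two returned sums combine into a scalar mean loss:

loss_num, loss_den = generic_loss(logits, targets, hparams)
mean_loss = loss_num / tf.maximum(loss_den, 1e-8)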
Example #2
        def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
            """Exit the loop either if reach decode_length or EOS."""
            length = common_layers.shape_list(result)[1]

            not_overflow = length < decode_length

            if self._problem_hparams.stop_at_eos:

                def fn_not_eos():
                    return tf.not_equal(  # Check whether the last predicted element is an EOS.
                        tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID)

                not_eos = tf.cond(
                    # Only check for early stopping once at least one element has
                    # been decoded (otherwise fn_not_eos would fail on an empty result).
                    tf.not_equal(length, 0),
                    fn_not_eos,
                    lambda: True,
                )

                return tf.cond(
                    tf.equal(batch_size, 1),
                    # If batch_size == 1, we check EOS for early stopping.
                    lambda: tf.logical_and(not_overflow, not_eos),
                    # Else, just wait for max length
                    lambda: not_overflow)
            return not_overflow
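In the full decoder (Example #8 below) this predicate is passed as the cond argument of tf.while_loop.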
Example #3
def combine_last_two_dimensions(x):
    """Reshape x so that the last two dimension become one.

  Args:
    x: a Tensor with shape [..., a, b]

  Returns:
    a Tensor with shape [..., ab]
  """
    x_shape = common_layers.shape_list(x)
    a, b = x_shape[-2:]
    return tf.reshape(x, x_shape[:-2] + [a * b])
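An illustrative shape check (the head/dimension sizes are hypothetical):

x = tf.zeros([8, 10, 4, 16])          # e.g. [batch, length, heads, head_dim]
y = combine_last_two_dimensions(x)    # shape [8, 10, 64]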
Example #4

def add_position_timing_signal(x, hparams):
    """Add an n-dimensional embedding as the position (horizontal) timing signal.

    Args:
      x: a tensor with shape [batch, length, depth]
      hparams: model hyperparameters

    Returns:
      a Tensor with the same shape as x.
    """
    length = common_layers.shape_list(x)[1]
    channels = common_layers.shape_list(x)[2]
    signal = common_attention.get_timing_signal_1d(length, channels)

    if hparams.add_or_concat_timing_signal == "add":
        x_with_timing = x + common_layers.cast_like(signal, x)
        return x_with_timing
    else:
        raise ValueError("Unknown timing signal add or concat type: %s"
                         % hparams.add_or_concat_timing_signal)
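A minimal sketch of a call; the _HParams stub stands in for the real hyperparameter object and is purely hypothetical:

class _HParams(object):
    add_or_concat_timing_signal = "add"

x = tf.zeros([2, 7, 32])                       # [batch, length, depth]
x = add_position_timing_signal(x, _HParams())  # same shape, with positions encoded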
Example #5
def split_last_dimension(x, n):
    """Reshape x so that the last dimension becomes two dimensions.

  The first of these two dimensions is n.

  Args:
    x: a Tensor with shape [..., m]
    n: an integer.

  Returns:
    a Tensor with shape [..., n, m/n]
  """
    x_shape = common_layers.shape_list(x)
    m = x_shape[-1]
    if isinstance(m, int) and isinstance(n, int):
        assert m % n == 0
    return tf.reshape(x, x_shape[:-1] + [n, m // n])
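An illustrative shape check, the inverse of combine_last_two_dimensions from Example #3:

x = tf.zeros([8, 10, 64])         # e.g. [batch, length, depth]
y = split_last_dimension(x, 4)    # shape [8, 10, 4, 16]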
Example #6
def transformer_prepare_decoder(targets):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))

    decoder_input = common_layers.shift_right_3d(targets)

    return (decoder_input, decoder_self_attention_bias)
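A quick shape illustration (toy sizes):

targets = tf.zeros([2, 5, 32])                             # [batch, length, depth]
decoder_input, bias = transformer_prepare_decoder(targets)
# decoder_input: targets shifted right by one step along the length axis.
# bias: shape [1, 1, 5, 5] with large negative values above the diagonal,
# which masks attention to future positions.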
Example #7
def symbol_top(body_output, target_embedding_space):
    """Compute logits over the target vocabulary via the transposed embedding matrix."""
    body_output_shape = common_layers.shape_list(body_output)
    body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
    logits = tf.matmul(body_output, target_embedding_space, transpose_b=True)
    return tf.reshape(
        logits, body_output_shape[:-1] + [1, target_embedding_space.shape[0]])
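A shape walkthrough with hypothetical sizes:

body_output = tf.zeros([2, 5, 1, 512])        # [batch, length, 1, hidden]
embeddings = tf.zeros([1000, 512])            # [vocab_size, hidden]
logits = symbol_top(body_output, embeddings)  # shape [2, 5, 1, 1, 1000]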
Example #8
    def _greedy_infer(self, features):
        """Greedily decode one symbol at a time until EOS or decode_length."""
        features["inputs"] = tf.expand_dims(features["inputs"], 2)
        batch_size = common_layers.shape_list(features["inputs"])[0]
        initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)

        prefix_length = common_layers.shape_list(features["inputs"])[1]
        decode_length = prefix_length + self._hparams.decode_length
        result = initial_output

        vocab_size = self._problem_hparams.vocab_size["targets"]

        logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))
        logits_shape_inv = [None, None, None, None, None]

        loss = 0.0

        def infer_step(recent_output, recent_logits, unused_loss):
            """Run one decode step, appending the newest sample and logits."""
            padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
            features["targets"] = padded
            samples, logits, loss = self.sample(features)

            cur_sample = samples[:, -1, :, :]
            cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
            samples = tf.concat([recent_output, cur_sample], axis=1)

            logits = tf.concat([recent_logits, logits[:, -1:]], 1)

            return samples, logits, loss

        def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
            """Exit the loop either if reach decode_length or EOS."""
            length = common_layers.shape_list(result)[1]

            not_overflow = length < decode_length

            if self._problem_hparams.stop_at_eos:

                def fn_not_eos():
                    return tf.not_equal(  # Check whether the last predicted element is an EOS.
                        tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID)

                not_eos = tf.cond(
                    # Only check for early stopping once at least one element has
                    # been decoded (otherwise fn_not_eos would fail on an empty result).
                    tf.not_equal(length, 0),
                    fn_not_eos,
                    lambda: True,
                )

                return tf.cond(
                    tf.equal(batch_size, 1),
                    # If batch_size == 1, we check EOS for early stopping.
                    lambda: tf.logical_and(not_overflow, not_eos),
                    # Else, just wait for max length
                    lambda: not_overflow)
            return not_overflow

        result, logits, loss = tf.while_loop(
            while_exit_cond,
            infer_step,
            [result, logits, loss],
            shape_invariants=[
                tf.TensorShape([None, None, None, None]),
                tf.TensorShape(logits_shape_inv),
                tf.TensorShape([]),
            ],
            back_prop=False,
            parallel_iterations=1,
            name="",
        )

        return result
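The result tensor grows along axis 1 across iterations, which tf.while_loop only allows because of the shape_invariants argument; a standalone sketch of the same pattern with toy shapes:

def cond(t, acc):
    return t < 5

def body(t, acc):
    # Concatenate one new column each iteration; the relaxed invariant
    # [2, None] permits axis 1 to grow.
    return t + 1, tf.concat([acc, tf.ones([2, 1])], axis=1)

t0 = tf.constant(0)
acc0 = tf.zeros([2, 0])
_, acc = tf.while_loop(
    cond, body, [t0, acc0],
    shape_invariants=[t0.get_shape(), tf.TensorShape([2, None])])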