Example #1
    def test_sequence_beam_search(self, padded_decode):
        # batch_size*beam_size, max_decode_length, vocab_size
        probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
                                      [0.1, 0.8, 0.1]],
                                     [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3],
                                      [0.2, 0.1, 0.7]]])
        # batch_size, max_decode_length, num_heads, embed_size per head
        x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
        cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}

        def _get_test_symbols_to_logits_fn():
            """Test function that returns logits for next token."""
            def symbols_to_logits_fn(_, i, cache):
                logits = tf.cast(probabilities[:, i, :], tf.float32)
                return logits, cache

            return symbols_to_logits_fn

        predictions, _ = beam_search.sequence_beam_search(
            symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
            initial_ids=tf.zeros([1], dtype=tf.int32),
            initial_cache=cache,
            vocab_size=3,
            beam_size=2,
            alpha=0.6,
            max_decode_length=3,
            eos_id=9,
            padded_decode=padded_decode,
            dtype=tf.float32)
        self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
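The test above receives padded_decode as an argument, so it is presumably driven by a parameterized test runner. Below is a minimal sketch of such a wrapper; the class name, the decorator arguments, and the beam_search import path are assumptions, not taken from the snippet itself.

# Sketch of a parameterized wrapper for Example #1 (assumed names and paths).
import tensorflow as tf
from absl.testing import parameterized

from official.nlp.modeling.ops import beam_search  # assumed import path


class BeamSearchTest(tf.test.TestCase, parameterized.TestCase):

    @parameterized.parameters(False, True)
    def test_sequence_beam_search(self, padded_decode):
        # Body as in Example #1: build `probabilities`, the per-layer cache and
        # the symbols_to_logits_fn, then call beam_search.sequence_beam_search
        # with padded_decode=padded_decode and assert on the returned ids.
        ...


if __name__ == "__main__":
    tf.test.main()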
Example #2
  def predict_decode(self, start_token_ids, cache):
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=start_token_ids,
        initial_cache=cache,
        vocab_size=self.params.vocab_size,
        beam_size=self.params.beam_size,
        alpha=self.params.alpha,
        max_decode_length=self.params.len_title,
        padded_decode=self.params.get("padded_decode", False),
        eos_id=self.params.end_token_id)
    return decoded_ids, scores
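sequence_beam_search returns the decoded ids with one extra leading position holding the initial id (the test in Example #1 asserts a [1, 2, 4] result for max_decode_length=3) and one score per beam. The stand-in tensors below only illustrate those shapes and the top-beam extraction that Examples #3 to #5 apply to outputs like predict_decode's; the shapes are inferred from the test assertion and that slicing, not stated by this snippet.

import tensorflow as tf

# Stand-in tensors with the shapes produced by beam_search.sequence_beam_search:
# decoded_ids: [batch_size, beam_size, max_decode_length + 1] (leading initial id)
# scores:      [batch_size, beam_size]
batch_size, beam_size, max_decode_length = 2, 4, 3
decoded_ids = tf.zeros([batch_size, beam_size, max_decode_length + 1], tf.int32)
scores = tf.zeros([batch_size, beam_size], tf.float32)

# Keep the best beam and drop the initial id, as Examples #3-#5 do.
top_decoded_ids = decoded_ids[:, 0, 1:]  # [batch_size, max_decode_length]
top_scores = scores[:, 0]                # [batch_size]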
Example #3
    def call(self, inputs):
        """Calculate target logits or inferred target sequences.

    Args:
      inputs: a dictionary of tensors.
        Feature `inputs` (optional): int tensor with shape
          `[batch_size, input_length]`.
        Feature `embedded_inputs` (optional): float tensor with shape
          `[batch_size, input_length, embedding_width]`.
        Feature `targets` (optional): None or int tensor with shape
          `[batch_size, target_length]`.
        Feature `input_masks` (optional): When providing the `embedded_inputs`,
          the dictionary must provide a boolean mask marking the filled time
          steps. The shape of the tensor is `[batch_size, input_length]`.
        Either `inputs`, or both `embedded_inputs` and `input_masks`, must be
        present in the input dictionary. In the second case the projection of
        the integer tokens into the transformer embedding space is skipped and
        `input_masks` must be provided.

    Returns:
      If `targets` is defined, then return logits for each word in the target
      sequence, which is a float tensor with shape
      `(batch_size, target_length, vocab_size)`. If `targets` is `None`, then
      generate the output sequence one token at a time and return a
      dictionary {
          outputs: `(batch_size, decoded_length)`
          scores: `(batch_size, 1)`}
      Even when `float16` is used, the output tensor(s) are always `float32`.

    Raises:
      NotImplementedError: If the padded decode method is used on CPUs/GPUs.
    """
        # Prepare inputs to the layer stack by adding positional encodings and
        # applying dropout.
        targets = inputs.get("targets", None)
        (embedded_inputs, boolean_mask, input_shape,
         source_dtype) = self._parse_inputs(inputs)
        embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
        embedded_inputs *= tf.expand_dims(embedding_mask, -1)
        # Attention_mask generation.
        attention_mask = tf.cast(tf.reshape(
            boolean_mask, [input_shape[0], 1, input_shape[1]]),
                                 dtype=source_dtype)
        broadcast_ones = tf.ones(shape=[input_shape[0], input_shape[1], 1],
                                 dtype=source_dtype)
        attention_mask = broadcast_ones * attention_mask

        pos_encoding = self.position_embedding(embedded_inputs)
        pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
        encoder_inputs = embedded_inputs + pos_encoding

        encoder_inputs = self.encoder_dropout(encoder_inputs)

        encoder_outputs = self.encoder_layer(encoder_inputs,
                                             attention_mask=attention_mask)

        if targets is None:
            if self._padded_decode:
                max_decode_length = self._decode_max_length
            else:
                max_decode_length = self._decode_max_length or (
                    tf.shape(encoder_outputs)[1] + self._extra_decode_length)
            symbols_to_logits_fn = self._get_symbols_to_logits_fn(
                max_decode_length)

            batch_size = tf.shape(encoder_outputs)[0]
            # Create initial set of IDs that will be passed to symbols_to_logits_fn.
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)

            # Create cache storing decoder attention values for each layer.
            init_decode_length = (max_decode_length
                                  if self._padded_decode else 0)
            num_heads = self.decoder_layer.num_attention_heads
            dim_per_head = self._embedding_width // num_heads

            # Cache dtype needs to match beam_search dtype.
            # pylint: disable=g-complex-comprehension
            cache = {
                str(layer): {
                    "key":
                    tf.zeros([
                        batch_size, init_decode_length, num_heads, dim_per_head
                    ],
                             dtype=self.compute_dtype),
                    "value":
                    tf.zeros([
                        batch_size, init_decode_length, num_heads, dim_per_head
                    ],
                             dtype=self.compute_dtype)
                }
                for layer in range(self.decoder_layer.num_layers)
            }
            # pylint: enable=g-complex-comprehension

            # Add encoder output and attention bias to the cache.
            encoder_outputs = tf.cast(encoder_outputs,
                                      dtype=self.compute_dtype)
            attention_mask = tf.cast(tf.reshape(
                boolean_mask, [input_shape[0], 1, input_shape[1]]),
                                     dtype=self.compute_dtype)
            cache["encoder_outputs"] = encoder_outputs
            cache["encoder_decoder_attention_mask"] = attention_mask

            # Use beam search to find the top beam_size sequences and scores.
            decoded_ids, scores = beam_search.sequence_beam_search(
                symbols_to_logits_fn=symbols_to_logits_fn,
                initial_ids=initial_ids,
                initial_cache=cache,
                vocab_size=self._vocab_size,
                beam_size=self._beam_size,
                alpha=self._alpha,
                max_decode_length=max_decode_length,
                eos_id=self._eos_id,
                padded_decode=self._padded_decode,
                dtype=self.compute_dtype)

            # Get the top sequence for each batch element
            top_decoded_ids = decoded_ids[:, 0, 1:]
            top_scores = scores[:, 0]

            return {"outputs": top_decoded_ids, "scores": top_scores}

        # Shift targets to the right, and remove the last element
        targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
        decoder_inputs = self.embedding_lookup(targets)
        length = tf.shape(decoder_inputs)[1]
        pos_encoding = self.position_embedding(decoder_inputs)
        pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
        decoder_inputs += pos_encoding

        decoder_inputs = self.decoder_dropout(decoder_inputs)

        decoder_shape = tf_utils.get_shape_list(decoder_inputs,
                                                expected_rank=3)
        batch_size = decoder_shape[0]
        decoder_length = decoder_shape[1]

        self_attention_mask = tf.linalg.band_part(tf.ones([length, length]),
                                                  -1, 0)
        self_attention_mask = tf.reshape(self_attention_mask,
                                         [1, length, length])
        self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

        attention_mask = tf.cast(tf.expand_dims(boolean_mask, axis=1),
                                 dtype=source_dtype)
        attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

        outputs = self.decoder_layer(decoder_inputs,
                                     encoder_outputs,
                                     self_attention_mask=self_attention_mask,
                                     cross_attention_mask=attention_mask)
        logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                        outputs)
        # Model outputs should be float32 to avoid numeric issues.
        # https://www.tensorflow.org/guide/mixed_precision#building_the_model
        logits = tf.cast(logits, tf.float32)
        return logits
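The docstring above describes two entry points into call(): a dict with integer "inputs" (or precomputed "embedded_inputs" plus "input_masks"), optionally with "targets". The sketch below is a hedged usage illustration; `model` is a hypothetical, already-constructed instance of the surrounding class and is not defined in this example.

import tensorflow as tf

# Hypothetical inputs; `model` is assumed to be a constructed instance of the
# class that owns the call() above.
sources = tf.constant([[3, 7, 9, 0, 0]], dtype=tf.int32)  # [batch, input_length]
targets = tf.constant([[5, 2, 8, 1]], dtype=tf.int32)     # [batch, target_length]

# Training path: "targets" present -> float32 logits of shape
# [batch, target_length, vocab_size].
logits = model({"inputs": sources, "targets": targets})

# Inference path: no "targets" -> beam search, returning a dict with
# "outputs" (decoded ids) and "scores".
predictions = model({"inputs": sources})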
Example #4
    def predict(self, encoder_outputs, encoder_decoder_attention_bias,
                training):
        """Return predicted sequence."""
        encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
        if self.params["padded_decode"]:
            batch_size = encoder_outputs.shape.as_list()[0]
            input_length = encoder_outputs.shape.as_list()[1]
        else:
            batch_size = tf.shape(encoder_outputs)[0]
            input_length = tf.shape(encoder_outputs)[1]
        max_decode_length = input_length + self.params["extra_decode_length"]
        encoder_decoder_attention_bias = tf.cast(
            encoder_decoder_attention_bias, self.params["dtype"])

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(
            max_decode_length, training)

        # Create initial set of IDs that will be passed into symbols_to_logits_fn.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        # Create cache storing decoder attention values for each layer.
        # pylint: disable=g-complex-comprehension
        init_decode_length = (max_decode_length
                              if self.params["padded_decode"] else 0)
        num_heads = self.params["num_heads"]
        dim_per_head = self.params["hidden_size"] // num_heads
        cache = {
            "layer_%d" % layer: {
                "k":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"]),
                "v":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"])
            }
            for layer in range(self.params["num_hidden_layers"])
        }
        # pylint: enable=g-complex-comprehension

        # Add encoder output and attention bias to the cache.
        cache["encoder_outputs"] = encoder_outputs
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        # Use beam search to find the top beam_size sequences and scores.
        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["vocab_size"],
            beam_size=self.params["beam_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            eos_id=EOS_ID,
            padded_decode=self.params["padded_decode"],
            dtype=self.params["dtype"])

        # Get the top sequence for each batch element
        top_decoded_ids = decoded_ids[:, 0, 1:]
        top_scores = scores[:, 0]

        return {"outputs": top_decoded_ids, "scores": top_scores}
Example #5
    def call(self, inputs):
        """Calculate target logits or inferred target sequences.

    Args:
      inputs: input tensor list of size 1 or 2.
        First item, inputs: int tensor with shape [batch_size, input_length].
        Second item (optional), targets: None or int tensor with shape
          [batch_size, target_length].

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence, a float tensor with shape
      [batch_size, target_length, vocab_size].
      If targets is None, then generate the output sequence one token at a
      time and return a dictionary {
          outputs: [batch_size, decoded_length]
          scores: [batch_size, float]}
      Even when float16 is used, the output tensor(s) are always float32.

    Raises:
      NotImplementedError: If the padded decode method is used on CPUs/GPUs.
    """
        if len(inputs) == 2:
            sources, targets = inputs[0], inputs[1]
        else:
            # Decoding path.
            sources, targets = inputs[0], None

        attention_bias = model_utils.get_padding_bias(sources)
        attention_bias = tf.cast(attention_bias, self._dtype)
        # Prepare inputs to the layer stack by adding positional encodings and
        # applying dropout.
        embedded_inputs = self.embedding_lookup(sources)
        embedding_mask = tf.cast(tf.not_equal(sources, 0),
                                 self.embedding_lookup.embeddings.dtype)
        embedded_inputs *= tf.expand_dims(embedding_mask, -1)
        embedded_inputs = tf.cast(embedded_inputs, self._dtype)
        # Attention_mask generation.
        input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
        attention_mask = tf.cast(tf.reshape(tf.not_equal(
            sources, 0), [input_shape[0], 1, input_shape[1]]),
                                 dtype=sources.dtype)
        broadcast_ones = tf.ones(shape=[input_shape[0], input_shape[1], 1],
                                 dtype=sources.dtype)
        attention_mask = broadcast_ones * attention_mask

        pos_encoding = self.position_embedding(inputs=embedded_inputs)
        pos_encoding = tf.cast(pos_encoding, self._dtype)
        encoder_inputs = embedded_inputs + pos_encoding

        encoder_inputs = self.encoder_dropout(encoder_inputs)

        encoder_outputs = self.encoder_layer(encoder_inputs,
                                             attention_mask=attention_mask)

        if targets is None:
            encoder_decoder_attention_bias = attention_bias
            encoder_outputs = tf.cast(encoder_outputs, self._dtype)
            if self._padded_decode:
                batch_size = encoder_outputs.shape.as_list()[0]
                max_decode_length = self._decode_max_length
            else:
                batch_size = tf.shape(encoder_outputs)[0]
                max_decode_length = self._decode_max_length or (
                    tf.shape(encoder_outputs)[1] + self._extra_decode_length)
            encoder_decoder_attention_bias = tf.cast(
                encoder_decoder_attention_bias, self._dtype)

            symbols_to_logits_fn = self._get_symbols_to_logits_fn(
                max_decode_length)

            # Create initial set of IDs that will be passed to symbols_to_logits_fn.
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)

            # Create cache storing decoder attention values for each layer.
            # pylint: disable=g-complex-comprehension
            init_decode_length = (max_decode_length
                                  if self._padded_decode else 0)
            num_heads = self.decoder_layer.num_attention_heads
            dim_per_head = self._embedding_width // num_heads

            cache = {
                str(layer): {
                    "key":
                    tf.zeros([
                        batch_size, init_decode_length, num_heads, dim_per_head
                    ],
                             dtype=self._dtype),
                    "value":
                    tf.zeros([
                        batch_size, init_decode_length, num_heads, dim_per_head
                    ],
                             dtype=self._dtype)
                }
                for layer in range(self.decoder_layer.num_layers)
            }

            # pylint: enable=g-complex-comprehension
            # Add encoder output and attention bias to the cache.
            cache["encoder_outputs"] = encoder_outputs
            cache[
                "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

            # Use beam search to find the top beam_size sequences and scores.
            decoded_ids, scores = beam_search.sequence_beam_search(
                symbols_to_logits_fn=symbols_to_logits_fn,
                initial_ids=initial_ids,
                initial_cache=cache,
                vocab_size=self._vocab_size,
                beam_size=self._beam_size,
                alpha=self._alpha,
                max_decode_length=max_decode_length,
                eos_id=EOS_ID,
                padded_decode=self._padded_decode,
                dtype=self._dtype)

            # Get the top sequence for each batch element
            top_decoded_ids = decoded_ids[:, 0, 1:]
            top_scores = scores[:, 0]

            return {"outputs": top_decoded_ids, "scores": top_scores}

        decoder_inputs = self.embedding_lookup(targets)
        embedding_mask = tf.cast(tf.not_equal(targets, 0),
                                 self.embedding_lookup.embeddings.dtype)
        decoder_inputs *= tf.expand_dims(embedding_mask, -1)
        decoder_inputs = tf.cast(decoder_inputs, self._dtype)
        # Shift targets to the right, and remove the last element
        decoder_inputs = tf.pad(decoder_inputs,
                                [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        length = tf.shape(decoder_inputs)[1]
        pos_encoding = self.position_embedding(decoder_inputs)
        pos_encoding = tf.cast(pos_encoding, self._dtype)
        decoder_inputs += pos_encoding

        decoder_inputs = self.decoder_dropout(decoder_inputs)

        decoder_shape = tf_utils.get_shape_list(decoder_inputs,
                                                expected_rank=3)
        batch_size = decoder_shape[0]
        decoder_length = decoder_shape[1]

        self_attention_mask = tf.linalg.band_part(
            tf.ones([length, length], dtype=tf.float32), -1, 0)
        self_attention_mask = tf.reshape(self_attention_mask,
                                         [1, length, length])
        self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

        attention_mask = tf.cast(tf.expand_dims(tf.not_equal(sources, 0),
                                                axis=1),
                                 dtype=sources.dtype)
        attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

        outputs = self.decoder_layer(decoder_inputs,
                                     encoder_outputs,
                                     memory_mask=self_attention_mask,
                                     target_mask=attention_mask)
        logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                        outputs)
        logits = tf.cast(logits, tf.float32)
        return logits
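Unlike Example #3, this call() takes a list rather than a dict: [sources, targets] for the training path and [sources] for beam-search decoding, with the EOS id coming from a module-level EOS_ID constant as in Example #4. A short hedged sketch follows; `model` is again a hypothetical, already-constructed instance of the surrounding class.

import tensorflow as tf

# Hypothetical usage of Example #5's call(); `model` is assumed to be a
# constructed instance of the class this method belongs to.
sources = tf.constant([[4, 11, 6, 0]], dtype=tf.int32)
targets = tf.constant([[9, 2, 1, 0]], dtype=tf.int32)

# Training: a list of size 2 -> float32 logits [batch, target_length, vocab_size].
logits = model([sources, targets])

# Decoding: a list of size 1 -> {"outputs": ..., "scores": ...} via beam search.
predictions = model([sources])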