Example #1
  def call(self, inputs, mode="train"):
    """Implements call().

    Args:
      inputs: a dictionary of tensors.
      mode: string, one of "train", "eval", or "predict".

    Returns:
      logits, decode_output_ids, output_log_probs for training;
      (decoded_ids, scores) for predict; logits over the top decoded ids
      for eval.
    """
    input_ids = inputs["input_ids"]
    input_mask = inputs["input_mask"]
    segment_ids = inputs["segment_ids"]
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])

    if mode not in ("train", "eval", "predict"):
      raise ValueError("Invalid call mode: %s" % mode)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        input_ids,
        bias_type="single_cross",
        padding_value=self.params.pad_token_id)
    if mode == "train":
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          all_encoder_outputs=all_encoder_outputs,
          target_ids=inputs["target_ids"],
          self_attention_bias=self_attention_bias)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    batch_size = tf.shape(input_ids)[0]
    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    # Add encoder output and attention bias to the cache.
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = all_encoder_outputs
    cache["attention_bias"] = encoder_decoder_attention_bias
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=all_encoder_outputs)
    top_decoded_ids = decoded_ids[:, 0, 1:]
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
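For orientation, here is a minimal usage sketch of the three modes accepted by call(). The feature keys match the example above; the model instance, and the helper name run_bert2bert_modes, are illustrative assumptions and not part of the original code.

def run_bert2bert_modes(model, features):
    # Sketch only: `model` is assumed to be an instance of the encoder-decoder
    # class defined above, and `features` a dict with the keys used in call():
    # "input_ids", "input_mask", "segment_ids", and (for training) "target_ids".

    # Training: teacher forcing over target_ids.
    train_outputs = model(features, mode="train")

    # Prediction: beam search; returns ids and scores for the whole beam.
    decoded_ids, scores = model(features, mode="predict")

    # Eval: the top beam hypothesis is re-scored to produce logits.
    eval_logits = model(features, mode="eval")
    return train_outputs, (decoded_ids, scores), eval_logits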
Example #2
    def _get_symbols_to_logits_fn(self, max_decode_length):
        """Returns a decoding function that calculates logits of the next tokens."""
        # Max decode length should be smaller than the positional embedding max
        # sequence length.
        decoder_self_attention_bias = decoder.get_attention_bias(
            input_tensor=None,
            bias_type="decoder_self",
            max_length=max_decode_length)

        def _symbols_to_logits_fn(ids, i, cache):
            """Generate logits for next candidate IDs."""
            if self.params.use_cache:
                target_length = 1
            else:
                target_length = i + 1
            decoder_inputs = dict(
                doc_attention_probs=self._expand_doc_attention_probs(
                    cache["doc_attention_probs"], target_length),
                all_encoder_outputs=cache["all_encoder_outputs"],
                attention_bias=cache["attention_bias"])
            logits = self.get_decode_logits(
                decoder_inputs,
                ids,
                decoder_self_attention_bias,
                step=i,
                cache=cache if self.params.use_cache else None)
            return logits, cache

        return _symbols_to_logits_fn
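The closure returned above is normally driven by the beam-search decoder; the toy greedy loop below only illustrates its calling convention (ids so far, step index, and cache in; next-token logits and the cache out). The greedy_decode helper is hypothetical and not part of the original code.

import tensorflow as tf

def greedy_decode(symbols_to_logits_fn, start_ids, cache, max_decode_length):
    # start_ids: int32 tensor of shape [batch] holding the start token id.
    ids = tf.expand_dims(start_ids, axis=1)  # [batch, 1]
    for i in range(max_decode_length):
        # At step i the closure sees everything decoded so far plus the cache.
        logits, cache = symbols_to_logits_fn(ids, i, cache)
        next_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)  # [batch]
        ids = tf.concat([ids, tf.expand_dims(next_ids, axis=1)], axis=1)
    return ids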
Example #3
  def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
    """Returns the logits for the given decoded ids."""
    target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
    decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
        target_ids, bias_type="decoder_self")
    decoder_inputs["target_ids"] = target_ids
    decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs)
    return logits
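The helper _add_sos_to_seq is used above but not shown. A plausible sketch, under the assumption that it prepends the start token and drops the last position so the shifted sequence keeps its original length (aligning the decoder inputs with the ids being re-scored):

import tensorflow as tf

def _add_sos_to_seq(seq, start_token_id):
    # Sketch only: prepend the start token and truncate, keeping the length fixed.
    batch_size = tf.shape(seq)[0]
    sos_ids = tf.ones([batch_size, 1], seq.dtype) * start_token_id
    shifted = tf.concat([sos_ids, seq], axis=1)
    return shifted[:, :-1]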
Example #4
    def _get_symbols_to_logits_fn(self, max_decode_length):
        """Returns a decoding function that calculates logits of the next tokens."""
        # Max decode length should be smaller than the positional embedding max
        # sequence length.
        decoder_self_attention_bias = decoder.get_attention_bias(
            input_tensor=None,
            bias_type="decoder_self",
            max_length=max_decode_length)

        def _symbols_to_logits_fn(ids, i, cache):
            """Generate logits for next candidate IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape [batch_size *
          beam_size, i + 1]
        i: Loop index
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
            decoder_inputs = dict(
                all_encoder_outputs=cache["all_encoder_outputs"],
                attention_bias=cache["attention_bias"])
            logits = self.get_decode_logits(
                decoder_inputs,
                ids,
                decoder_self_attention_bias,
                step=i,
                cache=cache if self.params.use_cache else None)
            return logits, cache

        return _symbols_to_logits_fn
Example #5
    def call(self, inputs, mode="training"):
        input_shape = tf_utils.get_shape_list(inputs["input_ids"],
                                              expected_rank=3)
        batch_size, num_docs, len_passage = (input_shape[0], input_shape[1],
                                             input_shape[2])
        input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
        input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
        segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
        all_encoder_outputs, _ = self.bert_layer(
            [input_ids, input_mask, segment_ids])
        encoder_outputs = tf.reshape(
            all_encoder_outputs[-1],
            [batch_size, num_docs, len_passage, self.params.hidden_size])
        doc_attention_mask = tf.reshape(
            tf.cast(
                tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
                tf.int32), [batch_size, num_docs])

        doc_attention_probs = self.doc_attention(encoder_outputs,
                                                 doc_attention_mask)
        encoder_decoder_attention_bias = decoder.get_attention_bias(
            inputs["input_ids"],
            bias_type="multi_cross",
            padding_value=self.params.pad_token_id)

        if mode == "train":
            target_length = tf_utils.get_shape_list(inputs["target_ids"],
                                                    expected_rank=2)[1]
            doc_attention_probs = self._expand_doc_attention_probs(
                doc_attention_probs, target_length)
            self_attention_bias = decoder.get_attention_bias(
                inputs["target_ids"], bias_type="decoder_self")
            decoder_inputs = dict(
                attention_bias=encoder_decoder_attention_bias,
                self_attention_bias=self_attention_bias,
                target_ids=inputs["target_ids"],
                all_encoder_outputs=encoder_outputs,
                doc_attention_probs=doc_attention_probs)
            decoder_outputs = self.decoder_layer(decoder_inputs)
            return self.train_decode(decoder_outputs)

        # Adds encoder output and attention bias to the cache.
        if self.params.use_cache:
            cache = self._init_cache(batch_size)
        else:
            cache = {}
        cache["all_encoder_outputs"] = [encoder_outputs]
        cache["attention_bias"] = encoder_decoder_attention_bias
        cache["doc_attention_probs"] = doc_attention_probs

        start_token_ids = tf.ones([batch_size],
                                  tf.int32) * self.params.start_token_id
        decoded_ids, scores = self.predict_decode(start_token_ids, cache)
        if mode == "predict":
            return decoded_ids[:, :self.params.beam_size,
                               1:], scores[:, :self.params.beam_size]

        top_decoded_ids = decoded_ids[:, 0, 1:]
        target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
        decoder_inputs = dict(
            attention_bias=encoder_decoder_attention_bias,
            all_encoder_outputs=[encoder_outputs],
            doc_attention_probs=self._expand_doc_attention_probs(
                doc_attention_probs, target_length))
        return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
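The helper _expand_doc_attention_probs is likewise not shown in these examples. A sketch under the assumption that the per-document weights of shape [batch, num_docs] are simply broadcast over attention heads and decoder positions (num_heads is passed explicitly here, whereas the model presumably reads it from its params):

import tensorflow as tf

def expand_doc_attention_probs(doc_attention_probs, target_length, num_heads):
    # Sketch only: [batch, num_docs] -> [batch, num_heads, target_length, num_docs].
    probs = doc_attention_probs[:, tf.newaxis, tf.newaxis, :]
    return tf.tile(probs, [1, num_heads, target_length, 1])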