def call(self, inputs, mode="train"): """Implements call(). Args: inputs: a dictionary of tensors. mode: string, an enum for mode, train/eval. Returns: logits, decode_output_ids, output_log_probs for training. top_decoded_ids for eval. """ input_ids = inputs["input_ids"] input_mask = inputs["input_mask"] segment_ids = inputs["segment_ids"] all_encoder_outputs, _ = self.bert_layer( [input_ids, input_mask, segment_ids]) if mode not in ("train", "eval", "predict"): raise ValueError("Invalid call mode: %s" % mode) encoder_decoder_attention_bias = decoder.get_attention_bias( input_ids, bias_type="single_cross", padding_value=self.params.pad_token_id) if mode == "train": self_attention_bias = decoder.get_attention_bias( inputs["target_ids"], bias_type="decoder_self") decoder_inputs = dict( attention_bias=encoder_decoder_attention_bias, all_encoder_outputs=all_encoder_outputs, target_ids=inputs["target_ids"], self_attention_bias=self_attention_bias) decoder_outputs = self.decoder_layer(decoder_inputs) return self.train_decode(decoder_outputs) batch_size = tf.shape(input_ids)[0] start_token_ids = tf.ones([batch_size], tf.int32) * self.params.start_token_id # Add encoder output and attention bias to the cache. if self.params.use_cache: cache = self._init_cache(batch_size) else: cache = {} cache["all_encoder_outputs"] = all_encoder_outputs cache["attention_bias"] = encoder_decoder_attention_bias decoded_ids, scores = self.predict_decode(start_token_ids, cache) if mode == "predict": return decoded_ids[:, :self.params.beam_size, 1:], scores[:, :self.params.beam_size] decoder_inputs = dict( attention_bias=encoder_decoder_attention_bias, all_encoder_outputs=all_encoder_outputs) top_decoded_ids = decoded_ids[:, 0, 1:] return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
def _get_symbols_to_logits_fn(self, max_decode_length):
  """Returns a decoding function that calculates logits of the next tokens."""
  # Max decode length should be smaller than the positional embedding max
  # sequence length.
  decoder_self_attention_bias = decoder.get_attention_bias(
      input_tensor=None,
      bias_type="decoder_self",
      max_length=max_decode_length)

  def _symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next candidate IDs."""
    # With a decode cache the decoder consumes only the newest position, so
    # the document attention probabilities are expanded to a single target
    # position; without a cache the whole prefix of length i + 1 is
    # re-decoded.
    if self.params.use_cache:
      target_length = 1
    else:
      target_length = i + 1
    decoder_inputs = dict(
        doc_attention_probs=self._expand_doc_attention_probs(
            cache["doc_attention_probs"], target_length),
        all_encoder_outputs=cache["all_encoder_outputs"],
        attention_bias=cache["attention_bias"])
    logits = self.get_decode_logits(
        decoder_inputs,
        ids,
        decoder_self_attention_bias,
        step=i,
        cache=cache if self.params.use_cache else None)
    return logits, cache

  return _symbols_to_logits_fn
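# A sketch of how the closure returned above can be driven step by step
# (illustrative). The real driver is the beam-search routine inside
# predict_decode; here a greedy loop stands in for it, and `model` and `cache`
# are assumed to be prepared exactly as `call` prepares them, including the
# "doc_attention_probs" entry.
def _example_greedy_steps(model, cache, start_token_ids, max_decode_length=8):
  import tensorflow as tf  # Local import to keep the sketch self-contained.
  symbols_to_logits_fn = model._get_symbols_to_logits_fn(max_decode_length)
  ids = tf.expand_dims(start_token_ids, axis=1)  # [batch_size, 1]
  for i in range(max_decode_length):
    # logits has shape [batch_size, vocab_size] for the (i + 1)-th position.
    logits, cache = symbols_to_logits_fn(ids, i, cache)
    next_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
    ids = tf.concat([ids, tf.expand_dims(next_ids, axis=1)], axis=1)
  return ids  # [batch_size, max_decode_length + 1], starting with SOS.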
def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
  """Returns the decoder logits for the given decoded ids."""
  target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
  decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
      target_ids, bias_type="decoder_self")
  decoder_inputs["target_ids"] = target_ids
  decoder_outputs = self.decoder_layer(decoder_inputs)
  logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                            decoder_outputs)
  return logits
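# For reference, a plausible shape-preserving implementation of the
# `_add_sos_to_seq` helper used above (an assumption: the real helper may
# differ, but `_get_logits_for_decode_ids` only needs the targets shifted
# right so that position t predicts token t of `top_decoded_ids`).
def _example_add_sos_to_seq(seq, start_token_id):
  import tensorflow as tf  # Local import to keep the sketch self-contained.
  batch_size = tf.shape(seq)[0]
  sos_ids = tf.fill([batch_size, 1], start_token_id)
  # Prepend SOS and drop the last token so the sequence length is unchanged.
  return tf.concat([sos_ids, seq], axis=1)[:, :-1]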
def _get_symbols_to_logits_fn(self, max_decode_length):
  """Returns a decoding function that calculates logits of the next tokens."""
  # Max decode length should be smaller than the positional embedding max
  # sequence length.
  decoder_self_attention_bias = decoder.get_attention_bias(
      input_tensor=None,
      bias_type="decoder_self",
      max_length=max_decode_length)

  def _symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next candidate IDs.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1].
      i: Loop index.
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of (logits with shape [batch_size * beam_size, vocab_size],
      updated cache values).
    """
    decoder_inputs = dict(
        all_encoder_outputs=cache["all_encoder_outputs"],
        attention_bias=cache["attention_bias"])
    logits = self.get_decode_logits(
        decoder_inputs,
        ids,
        decoder_self_attention_bias,
        step=i,
        cache=cache if self.params.use_cache else None)
    return logits, cache

  return _symbols_to_logits_fn
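# A sketch of how this closure typically plugs into beam search (illustrative).
# It assumes the TF Model Garden helper
# official.nlp.modeling.ops.beam_search.sequence_beam_search; the import path
# and keyword names vary across releases, and `model`, `cache`,
# `start_token_ids`, and the `alpha`/`end_token_id` params are assumptions
# about the surrounding configuration.
def _example_beam_search(model, cache, start_token_ids, vocab_size,
                         max_decode_length=32):
  from official.nlp.modeling.ops import beam_search
  symbols_to_logits_fn = model._get_symbols_to_logits_fn(max_decode_length)
  decoded_ids, scores = beam_search.sequence_beam_search(
      symbols_to_logits_fn=symbols_to_logits_fn,
      initial_ids=start_token_ids,
      initial_cache=cache,
      vocab_size=vocab_size,
      beam_size=model.params.beam_size,
      alpha=model.params.alpha,
      max_decode_length=max_decode_length,
      eos_id=model.params.end_token_id)
  # decoded_ids: [batch_size, beam_size, max_decode_length + 1];
  # scores: [batch_size, beam_size].
  return decoded_ids, scores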
def call(self, inputs, mode="training"): input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3) batch_size, num_docs, len_passage = (input_shape[0], input_shape[1], input_shape[2]) input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage]) input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage]) segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage]) all_encoder_outputs, _ = self.bert_layer( [input_ids, input_mask, segment_ids]) encoder_outputs = tf.reshape( all_encoder_outputs[-1], [batch_size, num_docs, len_passage, self.params.hidden_size]) doc_attention_mask = tf.reshape( tf.cast( tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2, tf.int32), [batch_size, num_docs]) doc_attention_probs = self.doc_attention(encoder_outputs, doc_attention_mask) encoder_decoder_attention_bias = decoder.get_attention_bias( inputs["input_ids"], bias_type="multi_cross", padding_value=self.params.pad_token_id) if mode == "train": target_length = tf_utils.get_shape_list( inputs["target_ids"], expected_rank=2)[1] doc_attention_probs = self._expand_doc_attention_probs( doc_attention_probs, target_length) self_attention_bias = decoder.get_attention_bias( inputs["target_ids"], bias_type="decoder_self") decoder_inputs = dict( attention_bias=encoder_decoder_attention_bias, self_attention_bias=self_attention_bias, target_ids=inputs["target_ids"], all_encoder_outputs=encoder_outputs, doc_attention_probs=doc_attention_probs) decoder_outputs = self.decoder_layer(decoder_inputs) return self.train_decode(decoder_outputs) # Adds encoder output and attention bias to the cache. if self.params.use_cache: cache = self._init_cache(batch_size) else: cache = {} cache["all_encoder_outputs"] = [encoder_outputs] cache["attention_bias"] = encoder_decoder_attention_bias cache["doc_attention_probs"] = doc_attention_probs start_token_ids = tf.ones([batch_size], tf.int32) * self.params.start_token_id decoded_ids, scores = self.predict_decode(start_token_ids, cache) if mode == "predict": return decoded_ids[:, :self.params.beam_size, 1:], scores[:, :self.params.beam_size] top_decoded_ids = decoded_ids[:, 0, 1:] target_length = tf_utils.get_shape_list(top_decoded_ids)[-1] decoder_inputs = dict( attention_bias=encoder_decoder_attention_bias, all_encoder_outputs=[encoder_outputs], doc_attention_probs=self._expand_doc_attention_probs( doc_attention_probs, target_length)) return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)