def _get_top_k(scores1, scores2, k, max_span_size, support2question):
    max_support_length = tf.shape(scores1)[1]
    doc_idx, pointer1, topk_scores1 = segment_top_k(scores1, support2question, k)

    # [num_questions * topk]
    doc_idx_flat = tf.reshape(doc_idx, [-1])
    pointer_flat1 = tf.reshape(pointer1, [-1])

    # [num_questions * topk, support_length]
    scores_gathered2 = tf.gather(scores2, doc_idx_flat)
    if max_span_size < 0:
        pointer_flat1, max_span_size = pointer_flat1 + max_span_size + 1, -max_span_size
    left_mask = misc.mask_for_lengths(tf.cast(pointer_flat1, tf.int32),
                                      max_support_length, mask_right=False)
    right_mask = misc.mask_for_lengths(tf.cast(pointer_flat1 + max_span_size, tf.int32),
                                       max_support_length)
    scores_gathered2 = scores_gathered2 + left_mask + right_mask

    pointer2 = tf.argmax(scores_gathered2, axis=1, output_type=tf.int32)

    topk_score2 = tf.gather_nd(scores2, tf.stack([doc_idx_flat, pointer2], 1))

    return doc_idx, pointer1, tf.reshape(pointer2, [-1, k]), topk_scores1 + tf.reshape(topk_score2, [-1, k])
def eval():
    # [num_questions * topk, support_length]
    left_mask = misc.mask_for_lengths(tf.cast(start_pointer, tf.int32),
                                      max_support_length, mask_right=False)
    right_mask = misc.mask_for_lengths(tf.cast(start_pointer + max_span_size, tf.int32),
                                       max_support_length)
    masked_end_scores = end_scores + left_mask + right_mask
    predicted_ends = tf.argmax(masked_end_scores, axis=1, output_type=tf.int32)

    return (start_scores, masked_end_scores,
            tf.gather(doc_idx_for_support, doc_idx_flat),
            start_pointer, predicted_ends)
def bilinear_answer_layer(size, encoded_question, question_length, encoded_support,
                          support_length, support2question, answer2support, is_eval,
                          topk=1, max_span_size=10000):
    """Answer layer for multiple paragraph QA."""
    # computing single time attention over question
    size = encoded_support.get_shape()[-1].value
    question_state = compute_question_state(encoded_question, question_length)

    # compute logits
    hidden = tf.gather(tf.layers.dense(question_state, 2 * size, name="hidden"), support2question)
    hidden_start, hidden_end = tf.split(hidden, 2, 1)

    support_mask = misc.mask_for_lengths(support_length)

    start_scores = tf.einsum('ik,ijk->ij', hidden_start, encoded_support)
    start_scores = start_scores + support_mask

    end_scores = tf.einsum('ik,ijk->ij', hidden_end, encoded_support)
    end_scores = end_scores + support_mask

    return compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                         topk, max_span_size)
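# Illustrative sketch (not part of the original module): how the index tensors are wired for a
# multi-paragraph batch with 2 questions and 3 supports (supports 0 and 1 belong to question 0,
# support 2 to question 1), with one gold answer per question. Shapes are made up for the
# example; `compute_spans` is assumed to be the span-extraction helper shared by the answer
# layers in this file.
def _example_bilinear_answer_layer():
    encoded_question = tf.random_normal([2, 6, 32])       # [num_questions, max_q_len, dim]
    question_length = tf.constant([6, 4], tf.int32)
    encoded_support = tf.random_normal([3, 40, 32])       # [num_supports, max_s_len, dim]
    support_length = tf.constant([40, 30, 25], tf.int32)
    support2question = tf.constant([0, 0, 1], tf.int32)   # support i belongs to question support2question[i]
    answer2support = tf.constant([0, 2], tf.int32)        # answer j lives in support answer2support[j]
    is_eval = tf.constant(False)
    return bilinear_answer_layer(32, encoded_question, question_length, encoded_support,
                                 support_length, support2question, answer2support, is_eval,
                                 topk=1, max_span_size=16)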
def mlp_answer_layer(size, encoded_question, question_length, encoded_support, support_length,
                     support2question, answer2support, is_eval, topk=1, max_span_size=10000):
    """Answer layer for multiple paragraph QA."""
    # computing single time attention over question
    question_state = compute_question_state(encoded_question, question_length)

    # compute logits
    static_input = tf.concat([tf.gather(tf.expand_dims(question_state, 1), support2question) * encoded_support,
                              encoded_support], 2)

    hidden = tf.gather(tf.layers.dense(question_state, 2 * size, name="hidden_1"), support2question)
    hidden = tf.layers.dense(static_input, 2 * size, use_bias=False,
                             name="hidden_2") + tf.expand_dims(hidden, 1)
    hidden_start, hidden_end = tf.split(tf.nn.relu(hidden), 2, 2)

    support_mask = misc.mask_for_lengths(support_length)

    start_scores = tf.layers.dense(hidden_start, 1, use_bias=False, name="start_scores")
    start_scores = tf.squeeze(start_scores, [2])
    start_scores = start_scores + support_mask

    end_scores = tf.layers.dense(hidden_end, 1, use_bias=False, name="end_scores")
    end_scores = tf.squeeze(end_scores, [2])
    end_scores = end_scores + support_mask

    return compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                         topk, max_span_size)
def non_zero_batchsize_op():
    max_length = tf.shape(seq)[1]
    encoded = tf.nn.embedding_lookup(ctxt_word_embeddings, seq)
    one_hot = [0.0] * num_sequences
    one_hot[i] = 1.0
    mode_feature = tf.constant([[one_hot]], tf.float32)
    mode_feature = tf.tile(mode_feature, tf.stack([num_seq, max_length, 1]))
    encoded = tf.concat([encoded, mode_feature], 2)
    encoded = modular_encoder.modular_encoder(
        sequence_module, {'text': encoded}, {'text': length}, {'text': None},
        size, 1.0 - keep_prob, is_eval)[0]['text']

    mask = misc.mask_for_lengths(length, max_length, mask_right=False, value=1.0)
    encoded = encoded * tf.expand_dims(mask, 2)

    seq_lemmas = tf.gather(word2lemma_off, tf.reshape(seq, [-1]))
    new_lemma_embeddings = tf.unsorted_segment_max(
        tf.reshape(encoded, [-1, size]), seq_lemmas, tf.reduce_max(word2lemma_off) + 1)
    new_lemma_embeddings = tf.nn.relu(new_lemma_embeddings)

    return tf.gather(new_lemma_embeddings, word2lemma_off)
def apply_attention(attn_scores, states, length, is_self=False, with_sentinel=True,
                    reuse=False, seq2_to_seq1=None):
    attn_scores += tf.expand_dims(misc.mask_for_lengths(length, tf.shape(attn_scores)[2]), 1)
    softmax = tf.nn.softmax if seq2_to_seq1 is None \
        else lambda x: segment.segment_softmax(x, seq2_to_seq1)
    if is_self:
        # exclude attending to state itself
        attn_scores += tf.expand_dims(tf.diag(tf.fill([tf.shape(attn_scores)[1]], -1e6)), 0)
    if with_sentinel:
        with tf.variable_scope('sentinel', reuse=reuse):
            s = tf.get_variable('score', [1, 1, 1], tf.float32, tf.zeros_initializer())
            s = tf.tile(s, [tf.shape(attn_scores)[0], tf.shape(attn_scores)[1], 1])
        attn_probs = softmax(tf.concat([s, attn_scores], 2))
        attn_probs = attn_probs[:, :, 1:]
    else:
        attn_probs = softmax(attn_scores)
    attn_states = tf.einsum('abd,adc->abc', attn_probs, states)
    if seq2_to_seq1 is not None:
        attn_states = tf.unsorted_segment_sum(attn_states, seq2_to_seq1,
                                              tf.reduce_max(seq2_to_seq1) + 1)
    return attn_scores, attn_probs, attn_states
def compute_question_state(encoded_question, question_length):
    attention_scores = tf.layers.dense(encoded_question, 1, name="question_attention")
    q_mask = misc.mask_for_lengths(question_length)
    attention_scores = attention_scores + tf.expand_dims(q_mask, 2)
    question_attention_weights = tf.nn.softmax(attention_scores, 1, name="question_attention_weights")
    question_state = tf.reduce_sum(question_attention_weights * encoded_question, 1)
    return question_state
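# Illustrative sketch (not part of the original module): compute_question_state turns a batch of
# encoded questions [batch, max_time, dim] plus their lengths [batch] into one attention-weighted
# summary vector per question [batch, dim]. The random inputs below stand in for real encoder outputs.
def _example_compute_question_state():
    encoded_question = tf.random_normal([2, 7, 16])   # 2 questions, up to 7 tokens, dim 16
    question_length = tf.constant([5, 7], tf.int32)   # actual token counts
    question_state = compute_question_state(encoded_question, question_length)
    return question_state                             # [2, 16]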
def bidaf_answer_layer(encoded_support_start, encoded_support_end, support_length,
                       support2question, answer2support, is_eval, topk=1, max_span_size=10000):
    # BiLSTM(M) = M^2 = encoded_support_end
    start_scores = tf.squeeze(tf.layers.dense(encoded_support_start, 1, use_bias=False), 2)
    end_scores = tf.squeeze(tf.layers.dense(encoded_support_end, 1, use_bias=False), 2)

    # mask out-of-bounds slots by adding -1000
    support_mask = misc.mask_for_lengths(support_length)
    start_scores = start_scores + support_mask
    end_scores = end_scores + support_mask

    return compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                         topk=topk, max_span_size=max_span_size)
def conv_char_embedding(num_chars, repr_dim, word_chars, word_lengths, word_sequences=None,
                        conv_width=5, emb_initializer=tf.random_normal_initializer(0.0, 0.1),
                        scope=None):
    """Build simple convolutional character-based embeddings for words using a single filter of fixed width.

    After the convolution, max-pooling over characters is applied for each filter.
    If word sequences are given, these will be embedded with the newly created embeddings.
    """
    # fixed padding on character level
    pad = tf.zeros(tf.stack([tf.shape(word_lengths)[0], conv_width // 2]), tf.int32)
    word_chars = tf.concat([pad, word_chars, pad], 1)

    with tf.variable_scope(scope or "char_embeddings"):
        char_embedding_matrix = \
            tf.get_variable("char_embedding_matrix", shape=(num_chars, repr_dim),
                            initializer=emb_initializer, trainable=True)

        max_word_length = tf.reduce_max(word_lengths)
        embedded_chars = tf.nn.embedding_lookup(char_embedding_matrix, tf.cast(word_chars, tf.int32))

        with tf.variable_scope("conv"):
            # create filter like this to get fan-in and fan-out right for initializers depending on those
            filter = tf.get_variable("filter", [conv_width * repr_dim, repr_dim])
            filter_reshaped = tf.reshape(filter, [conv_width, repr_dim, repr_dim])
            # [B, T, S + pad_right]
            conv_out = tf.nn.conv1d(embedded_chars, filter_reshaped, 1, "VALID")
            conv_mask = tf.expand_dims(
                misc.mask_for_lengths(word_lengths, max_length=max_word_length), 2)
            conv_out = conv_out + conv_mask

        embedded_words = tf.reduce_max(conv_out, 1)

    if word_sequences is None:
        return embedded_words

    if not isinstance(word_sequences, list):
        word_sequences = [word_sequences]
    all_embedded = []
    for word_idxs in word_sequences:
        all_embedded.append(tf.nn.embedding_lookup(embedded_words, word_idxs))

    return all_embedded
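# Illustrative sketch (not part of the original module): embed two words character by character and
# then look up those word embeddings for a toy sequence of word ids. The character ids and shapes are
# made up for the example.
def _example_conv_char_embedding():
    word_chars = tf.constant([[3, 4, 5, 6], [6, 7, 0, 0]], tf.int32)   # [num_words, max_word_len]
    word_lengths = tf.constant([4, 2], tf.int32)                       # [num_words]
    word_ids = tf.constant([[0, 1, 1]], tf.int32)                      # one sequence of word indices
    embedded_seq, = conv_char_embedding(num_chars=10, repr_dim=8,
                                        word_chars=word_chars, word_lengths=word_lengths,
                                        word_sequences=[word_ids])
    return embedded_seq                                                # [1, 3, 8]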
def san_answer_layer(size, encoded_question, question_length, encoded_support, support_length,
                     support2question, answer2support, is_eval, topk=1, max_span_size=10000,
                     num_steps=5, dropout=0.4, **kwargs):
    question_state = compute_question_state(encoded_question, question_length)
    question_state = tf.layers.dense(question_state, encoded_support.get_shape()[-1].value, tf.tanh)
    question_state = tf.gather(question_state, support2question)

    cell = tf.contrib.rnn.GRUBlockCell(size)

    all_start_scores = []
    all_end_scores = []
    support_mask = misc.mask_for_lengths(support_length)
    for i in range(num_steps):
        with tf.variable_scope('SAN', reuse=i > 0):
            question_state = tf.expand_dims(question_state, 1)
            support_attn = attention.bilinear_attention(
                question_state, encoded_support, support_length, False, False)[2]
            question_state = tf.squeeze(question_state, 1)
            support_attn = tf.squeeze(support_attn, 1)
            question_state = cell(support_attn, question_state)[0]

            hidden_start = tf.layers.dense(question_state, size, name="hidden_start")
            start_scores = tf.einsum('ik,ijk->ij', hidden_start, encoded_support)
            start_scores = start_scores + support_mask

            start_probs = segment_softmax(start_scores, support2question)
            start_states = tf.einsum('ij,ijk->ik', start_probs, encoded_support)
            start_states = tf.unsorted_segment_sum(start_states, support2question,
                                                   tf.shape(question_length)[0])
            start_states = tf.gather(start_states, support2question)

            hidden_end = tf.layers.dense(tf.concat([question_state, start_states], 1), size,
                                         name="hidden_end")
            end_scores = tf.einsum('ik,ijk->ij', hidden_end, encoded_support)
            end_scores = end_scores + support_mask

            all_start_scores.append(start_scores)
            all_end_scores.append(end_scores)

    all_start_scores = tf.stack(all_start_scores)
    all_end_scores = tf.stack(all_end_scores)

    # stochastic prediction dropout: randomly drop steps during training,
    # average over all steps during evaluation
    dropout_mask = tf.nn.dropout(tf.ones([num_steps, 1, 1]), 1.0 - dropout)
    all_start_scores = tf.cond(is_eval, lambda: all_start_scores,
                               lambda: all_start_scores * dropout_mask)
    all_end_scores = tf.cond(is_eval, lambda: all_end_scores,
                             lambda: all_end_scores * dropout_mask)

    start_scores = tf.reduce_mean(all_start_scores, axis=0)
    end_scores = tf.reduce_mean(all_end_scores, axis=0)

    return compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                         topk=topk, max_span_size=max_span_size)
def train():
    gathered_end_scores = tf.gather(end_scores, answer2support)
    gathered_start_scores = tf.gather(start_scores, answer2support)

    if correct_start is not None:
        # assuming we know the correct start, we only consider ends after that
        left_mask = misc.mask_for_lengths(tf.cast(correct_start, tf.int32),
                                          max_support_length, mask_right=False)
        gathered_end_scores = gathered_end_scores + left_mask

    predicted_start_pointer = tf.argmax(gathered_start_scores, axis=1, output_type=tf.int32)
    predicted_end_pointer = tf.argmax(gathered_end_scores, axis=1, output_type=tf.int32)

    return (start_scores, end_scores,
            tf.gather(doc_idx_for_support, answer2support),
            predicted_start_pointer, predicted_end_pointer)
def nli_model(size, num_classes, emb_question, question_length, emb_support, support_length):
    fused_rnn = tf.contrib.rnn.LSTMBlockFusedCell(size)

    # encode the question; its final states initialize the support RNN
    _, q_states = fused_birnn(fused_rnn, emb_question, sequence_length=question_length,
                              dtype=tf.float32, time_major=False, scope="question_rnn")

    outputs, _ = fused_birnn(fused_rnn, emb_support, sequence_length=support_length,
                             dtype=tf.float32, initial_state=q_states, time_major=False,
                             scope="support_rnn")

    # [batch, T, 2 * dim] -> [batch, dim]
    outputs = tf.concat([outputs[0], outputs[1]], axis=2)
    hidden = tf.layers.dense(outputs, size, tf.nn.relu, name="hidden") * tf.expand_dims(
        misc.mask_for_lengths(support_length, max_length=tf.shape(outputs)[1],
                              mask_right=False, value=1.0), 2)
    hidden = tf.reduce_max(hidden, axis=1)

    # [batch, dim] -> [batch, num_classes]
    outputs = tf.layers.dense(hidden, num_classes, name="classification")
    return outputs
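# Illustrative sketch (not part of the original module): run the NLI classifier on a toy batch of
# pre-embedded hypothesis/premise pairs. The random embeddings are stand-ins for real word embeddings.
def _example_nli_model():
    emb_question = tf.random_normal([4, 9, 50])            # hypotheses: [batch, time, emb_dim]
    question_length = tf.constant([9, 7, 5, 3], tf.int32)
    emb_support = tf.random_normal([4, 12, 50])            # premises: [batch, time, emb_dim]
    support_length = tf.constant([12, 10, 8, 6], tf.int32)
    logits = nli_model(size=64, num_classes=3,
                       emb_question=emb_question, question_length=question_length,
                       emb_support=emb_support, support_length=support_length)
    return logits                                          # [4, 3]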
def bidaf_layer(seq1, seq1_length, seq2, seq2_length, seq2_to_seq1=None, **kwargs):
    """Encodes seq1 conditioned on seq2, e.g., using word-by-word attention."""
    attn_scores, attn_probs, seq2_weighted = attention.diagonal_bilinear_attention(
        seq1, seq2, seq2_length, False, seq2_to_seq1=seq2_to_seq1)

    attn_scores += tf.expand_dims(mask_for_lengths(seq1_length, tf.shape(attn_scores)[1]), 2)
    max_seq1 = tf.reduce_max(attn_scores, 2)
    seq1_attention = tf.nn.softmax(max_seq1, 1)
    seq1_weighted = tf.einsum('ij,ijk->ik', seq1_attention, seq1)
    seq1_weighted = tf.expand_dims(seq1_weighted, 1)
    seq1_weighted = tf.tile(seq1_weighted, [1, tf.shape(seq1)[1], 1])

    return tf.concat([seq2_weighted, seq1 * seq2_weighted, seq1 * seq1_weighted], 2)
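# Illustrative sketch (not part of the original module): fuse a first sequence (e.g. a support)
# with a second sequence (e.g. a question) via bidirectional attention flow. Inputs are random
# stand-ins; the output concatenates three interaction features, so its last dimension is 3 * dim.
def _example_bidaf_layer():
    seq1 = tf.random_normal([2, 20, 32])             # e.g. support: [batch, support_len, dim]
    seq1_length = tf.constant([20, 15], tf.int32)
    seq2 = tf.random_normal([2, 8, 32])              # e.g. question: [batch, question_len, dim]
    seq2_length = tf.constant([8, 6], tf.int32)
    return bidaf_layer(seq1, seq1_length, seq2, seq2_length)   # [2, 20, 96]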
def attention_softmax(attn_scores, length=None):
    if length is not None:
        attn_scores += misc.mask_for_lengths(length, tf.shape(attn_scores)[2])
    return tf.nn.softmax(attn_scores)
def conv_char_embedding_multi_filter(num_chars, filter_sizes, embedding_size, word_chars,
                                     word_lengths, word_sequences=None,
                                     emb_initializer=tf.random_normal_initializer(0.0, 0.1),
                                     projection_size=None, scope=None):
    """Build convolutional character-based embeddings for words with multiple filters.

    filter_sizes is a list; the position of each entry determines the corresponding conv width and its
    value the number of filters of that width (0 means no filter of that width). E.g., sizes [4, 0, 7, 8]
    create 4 conv filters of width 1, no filter of width 2, 7 of width 3 and 8 of width 4. After the
    convolution, max-pooling over characters is applied for each filter.

    embedding_size refers to the size of the character embeddings and projection_size, if given, to the
    final size of the embedded characters after a final projection. If it is None, no projection is
    applied and the resulting size is the sum of all filter sizes.

    If word sequences are given, these will be embedded with the newly created embeddings.
    """
    with tf.variable_scope(scope or "char_embeddings"):
        char_embedding_matrix = \
            tf.get_variable("char_embedding_matrix", shape=(num_chars, embedding_size),
                            initializer=emb_initializer, trainable=True)

        pad = tf.zeros(tf.stack([tf.shape(word_lengths)[0], len(filter_sizes) // 2]), tf.int32)
        word_chars = tf.concat([pad, word_chars, pad], 1)

        max_word_length = tf.reduce_max(word_lengths)
        embedded_chars = tf.nn.embedding_lookup(char_embedding_matrix, tf.cast(word_chars, tf.int32))

        conv_mask = tf.expand_dims(
            misc.mask_for_lengths(word_lengths, max_length=max_word_length), 2)

        embedded_words = []
        for i, size in enumerate(filter_sizes):
            if size == 0:
                continue
            conv_width = i + 1
            with tf.variable_scope("conv_%d" % conv_width):
                # create filter like this to get fan-in and fan-out right for initializers depending on those
                filter = tf.get_variable("filter", [conv_width * embedding_size, size])
                filter_reshaped = tf.reshape(filter, [conv_width, embedding_size, size])
                cut = len(filter_sizes) // 2 - conv_width // 2
                embedded_chars_conv = embedded_chars[:, cut:-cut, :] if cut else embedded_chars
                conv_out = tf.nn.conv1d(embedded_chars_conv, filter_reshaped, 1, "VALID")
                conv_out += conv_mask
                embedded_words.append(tf.reduce_max(conv_out, 1))

        embedded_words = tf.concat(embedded_words, 1)
        if projection_size is not None:
            embedded_words = tf.layers.dense(embedded_words, projection_size)

    if word_sequences is None:
        return embedded_words

    if not isinstance(word_sequences, list):
        word_sequences = [word_sequences]
    all_embedded = []
    for word_idxs in word_sequences:
        all_embedded.append(tf.nn.embedding_lookup(embedded_words, word_idxs))

    return all_embedded
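# Illustrative sketch (not part of the original module): filter_sizes=[4, 0, 7] creates 4 width-1
# filters, no width-2 filter and 7 width-3 filters, so each word embedding has size 4 + 7 = 11 before
# the final projection to 16. Character ids and shapes are made up for the example.
def _example_conv_char_embedding_multi_filter():
    word_chars = tf.constant([[3, 4, 5, 6], [6, 7, 0, 0]], tf.int32)   # [num_words, max_word_len]
    word_lengths = tf.constant([4, 2], tf.int32)
    embedded_words = conv_char_embedding_multi_filter(
        num_chars=10, filter_sizes=[4, 0, 7], embedding_size=8,
        word_chars=word_chars, word_lengths=word_lengths, projection_size=16)
    return embedded_words                                              # [num_words, 16]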
def create_output(self, shared_resources, input_tensors):
    tensors = TensorPortTensors(input_tensors)
    with tf.variable_scope("fast_qa", initializer=tf.contrib.layers.xavier_initializer()):
        # Some helpers
        batch_size = tf.shape(tensors.question_length)[0]
        max_question_length = tf.reduce_max(tensors.question_length)
        support_mask = misc.mask_for_lengths(tensors.support_length)

        input_size = shared_resources.embeddings.shape[-1]
        size = shared_resources.config["repr_dim"]
        with_char_embeddings = shared_resources.config.get("with_char_embeddings", False)

        # set shapes for inputs
        tensors.emb_question.set_shape([None, None, input_size])
        tensors.emb_support.set_shape([None, None, input_size])

        emb_question = tensors.emb_question
        emb_support = tensors.emb_support
        if with_char_embeddings:
            # compute combined embeddings
            [char_emb_question, char_emb_support] = conv_char_embedding(
                len(shared_resources.char_vocab), size, tensors.word_chars, tensors.word_char_length,
                [tensors.question_batch_words, tensors.support_batch_words])

            emb_question = tf.concat([emb_question, char_emb_question], 2)
            emb_support = tf.concat([emb_support, char_emb_support], 2)
            input_size += size

            # set shapes for inputs
            emb_question.set_shape([None, None, input_size])
            emb_support.set_shape([None, None, input_size])

        # compute encoder features
        question_features = tf.ones(tf.stack([batch_size, max_question_length, 2]))

        v_wiqw = tf.get_variable("v_wiq_w", [1, 1, input_size],
                                 initializer=tf.constant_initializer(1.0))

        wiq_w = tf.matmul(tf.gather(emb_question * v_wiqw, tensors.support2question),
                          emb_support, adjoint_b=True)
        wiq_w = wiq_w + tf.expand_dims(support_mask, 1)

        question_binary_mask = tf.gather(
            tf.sequence_mask(tensors.question_length, dtype=tf.float32),
            tensors.support2question)
        wiq_w = tf.reduce_sum(tf.nn.softmax(wiq_w) * tf.expand_dims(question_binary_mask, 2), [1])

        # [B, L, 2]
        support_features = tf.stack([tensors.word_in_question, wiq_w], 2)

        # highway layer to allow for interaction between concatenated embeddings
        if with_char_embeddings:
            with tf.variable_scope("char_embeddings") as vs:
                emb_question = tf.layers.dense(emb_question, size, name="embeddings_projection")
                emb_question = highway_network(emb_question, 1)
                vs.reuse_variables()
                emb_support = tf.layers.dense(emb_support, size, name="embeddings_projection")
                emb_support = highway_network(emb_support, 1)

        keep_prob = 1.0 - shared_resources.config.get("dropout", 0.0)
        emb_question, emb_support = tf.cond(
            tensors.is_eval,
            lambda: (emb_question, emb_support),
            lambda: (tf.nn.dropout(emb_question, keep_prob,
                                   noise_shape=[1, 1, emb_question.get_shape()[-1].value]),
                     tf.nn.dropout(emb_support, keep_prob,
                                   noise_shape=[1, 1, emb_question.get_shape()[-1].value])))

        # extend embeddings with features
        emb_question_ext = tf.concat([emb_question, question_features], 2)
        emb_support_ext = tf.concat([emb_support, support_features], 2)

        # encode question and support
        encoder_type = shared_resources.config.get('encoder', 'lstm').lower()
        if encoder_type in ['lstm', 'sru', 'gru']:
            size = size + 2 if encoder_type == 'sru' else size  # to allow for use of residual in SRU
            encoded_question = encoder(emb_question_ext, tensors.question_length, size,
                                       module=encoder_type)
            encoded_support = encoder(emb_support_ext, tensors.support_length, size,
                                      module=encoder_type, reuse=True)
            projection_initializer = tf.constant_initializer(
                np.concatenate([np.eye(size), np.eye(size)]))
            encoded_question = tf.layers.dense(encoded_question, size, tf.tanh, use_bias=False,
                                               kernel_initializer=projection_initializer,
                                               name='projection_q')
            encoded_support = tf.layers.dense(encoded_support, size, tf.tanh, use_bias=False,
                                              kernel_initializer=projection_initializer,
                                              name='projection_s')
        else:
            raise ValueError("Only RNN encoders ('lstm', 'sru', 'gru') are allowed for FastQA!")

        answer_layer = shared_resources.config.get('answer_layer', 'conditional').lower()

        topk = tf.get_variable('topk', initializer=shared_resources.config.get('topk', 1),
                               dtype=tf.int32, trainable=False)
        topk_p = tf.placeholder(tf.int32, [], 'beam_size_setter')
        topk_assign = topk.assign(topk_p)
        self._topk_assign = lambda k: self.tf_session.run(topk_assign, {topk_p: k})

        if answer_layer == 'conditional':
            start_scores, end_scores, doc_idx, predicted_start_pointer, predicted_end_pointer = \
                conditional_answer_layer(
                    size, encoded_question, tensors.question_length, encoded_support,
                    tensors.support_length, tensors.correct_start, tensors.support2question,
                    tensors.answer2support, tensors.is_eval, topk=topk,
                    max_span_size=shared_resources.config.get("max_span_size", 10000))
        elif answer_layer == 'conditional_bilinear':
            start_scores, end_scores, doc_idx, predicted_start_pointer, predicted_end_pointer = \
                conditional_answer_layer(
                    size, encoded_question, tensors.question_length, encoded_support,
                    tensors.support_length, tensors.correct_start, tensors.support2question,
                    tensors.answer2support, tensors.is_eval, topk=topk,
                    max_span_size=shared_resources.config.get("max_span_size", 10000),
                    bilinear=True)
        elif answer_layer == 'bilinear':
            start_scores, end_scores, doc_idx, predicted_start_pointer, predicted_end_pointer = \
                bilinear_answer_layer(
                    size, encoded_question, tensors.question_length, encoded_support,
                    tensors.support_length, tensors.support2question, tensors.answer2support,
                    tensors.is_eval, topk=topk,
                    max_span_size=shared_resources.config.get("max_span_size", 10000))
        else:
            raise ValueError("Unknown answer layer '%s'!" % answer_layer)

        span = tf.stack([doc_idx, predicted_start_pointer, predicted_end_pointer], 1)

        return TensorPort.to_mapping(self.output_ports, (start_scores, end_scores, span))
def conditional_answer_layer(size, encoded_question, question_length, encoded_support, support_length,
                             correct_start, support2question, answer2support, is_eval,
                             topk=1, max_span_size=10000, bilinear=False):
    question_state = compute_question_state(encoded_question, question_length)
    question_state = tf.gather(question_state, support2question)

    # Prediction
    # start
    if bilinear:
        hidden_start = tf.layers.dense(question_state, size, name="hidden_start")
        start_scores = tf.einsum('ik,ijk->ij', hidden_start, encoded_support)
    else:
        static_input = tf.concat([tf.expand_dims(question_state, 1) * encoded_support, encoded_support], 2)
        hidden_start = tf.layers.dense(question_state, size, name="hidden_start_1")
        hidden_start = tf.layers.dense(
            static_input, size, use_bias=False, name="hidden_start_2") + tf.expand_dims(hidden_start, 1)
        start_scores = tf.layers.dense(tf.nn.relu(hidden_start), 1, use_bias=False, name="start_scores")
        start_scores = tf.squeeze(start_scores, [2])

    support_mask = misc.mask_for_lengths(support_length)
    start_scores = start_scores + support_mask

    max_support_length = tf.shape(start_scores)[1]
    _, _, num_doc_per_question = tf.unique_with_counts(support2question)
    offsets = tf.cumsum(num_doc_per_question, exclusive=True)
    doc_idx_for_support = tf.range(tf.shape(support2question)[0]) - tf.gather(offsets, support2question)

    doc_idx, start_pointer = tf.cond(
        is_eval,
        lambda: segment_top_k(start_scores, support2question, topk)[:2],
        lambda: (tf.expand_dims(answer2support, 1), tf.expand_dims(correct_start, 1)))

    doc_idx_flat = tf.reshape(doc_idx, [-1])
    start_pointer = tf.reshape(start_pointer, [-1])

    start_state = tf.gather_nd(encoded_support, tf.stack([doc_idx_flat, start_pointer], 1))
    start_state.set_shape([None, size])

    encoded_support_gathered = tf.gather(encoded_support, doc_idx_flat)
    question_state = tf.gather(question_state, doc_idx_flat)
    if bilinear:
        hidden_end = tf.layers.dense(tf.concat([question_state, start_state], 1), size,
                                     name="hidden_end")
        end_scores = tf.einsum('ik,ijk->ij', hidden_end, encoded_support_gathered)
    else:
        end_input = tf.concat([tf.expand_dims(start_state, 1) * encoded_support_gathered,
                               tf.gather(static_input, doc_idx_flat)], 2)
        hidden_end = tf.layers.dense(tf.concat([question_state, start_state], 1), size,
                                     name="hidden_end_1")
        hidden_end = tf.layers.dense(
            end_input, size, use_bias=False, name="hidden_end_2") + tf.expand_dims(hidden_end, 1)
        end_scores = tf.layers.dense(tf.nn.relu(hidden_end), 1, use_bias=False, name="end_scores")
        end_scores = tf.squeeze(end_scores, [2])

    end_scores = end_scores + tf.gather(support_mask, doc_idx_flat)

    def train():
        predicted_end_pointer = tf.argmax(end_scores, axis=1, output_type=tf.int32)
        return start_scores, end_scores, doc_idx, start_pointer, predicted_end_pointer

    def eval():
        # [num_questions * topk, support_length]
        left_mask = misc.mask_for_lengths(tf.cast(start_pointer, tf.int32),
                                          max_support_length, mask_right=False)
        right_mask = misc.mask_for_lengths(tf.cast(start_pointer + max_span_size, tf.int32),
                                           max_support_length)
        masked_end_scores = end_scores + left_mask + right_mask
        predicted_ends = tf.argmax(masked_end_scores, axis=1, output_type=tf.int32)

        return (start_scores, masked_end_scores,
                tf.gather(doc_idx_for_support, doc_idx_flat),
                start_pointer, predicted_ends)

    return tf.cond(is_eval, eval, train)