def test_decomposable_attention(self):
  """Verifies decomposable_attention maps two padded sequences to [batch, hidden]."""
  with tf.Graph().as_default():
    # Batch of 3; embedding dim 7. Sequence lengths deliberately include a
    # zero-length entry to exercise fully-masked rows.
    emb_a = tf.random_uniform([3, 5, 7])
    len_a = tf.constant([5, 2, 0])
    emb_b = tf.random_uniform([3, 8, 7])
    len_b = tf.constant([8, 6, 1])
    pooled = decatt.decomposable_attention(
        emb1=emb_a,
        len1=len_a,
        emb2=emb_b,
        len2=len_b,
        hidden_size=5,
        hidden_layers=2,
        dropout_ratio=0.1,
        mode=tf.estimator.ModeKeys.TRAIN)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      pooled_value = sess.run(pooled)
      # Output is one hidden_size-dim vector per batch element.
      self.assertAllEqual(pooled_value.shape, [3, 5])
def build_model(question_tok_wid, question_lens, context_tok_wid, context_lens,
                embedding_weights, mode):
  """Wrapper around the Decomposable Attention model for NQ long answer scoring.

  Args:
    question_tok_wid: <int32> [batch_size, question_len]
    question_lens: <int32> [batch_size]
    context_tok_wid: <int32> [batch_size, num_context, context_len]
    context_lens: <int32> [batch_size, num_context]
    embedding_weights: <float> [vocab_size, embed_dim]
    mode: One of the keys from tf.estimator.ModeKeys.

  Returns:
    context_scores: <float> [batch_size, num_context]
  """
  # Token-id -> embedding lookups for both sides.
  # <float> [batch_size, question_len, embed_dim]
  q_emb = tf.nn.embedding_lookup(embedding_weights, question_tok_wid)
  # <float> [batch_size, num_context, context_len, embed_dim]
  ctx_emb = tf.nn.embedding_lookup(embedding_weights, context_tok_wid)

  # Project question and context through one shared linear layer: the second
  # call reuses the "reduce_emb" variables created by the first.
  q_emb = tf.layers.dense(
      inputs=q_emb,
      units=FLAGS.hidden_size,
      activation=None,
      name="reduce_emb",
      reuse=False)
  ctx_emb = tf.layers.dense(
      inputs=ctx_emb,
      units=FLAGS.hidden_size,
      activation=None,
      name="reduce_emb",
      reuse=True)

  batch_size, num_contexts, max_context_len, embed_dim = (
      tensor_utils.shape(ctx_emb))
  _, max_question_len, _ = tensor_utils.shape(q_emb)

  # Fold the num_context axis into the batch axis so each (question, context)
  # pair is handled independently by the attention model.
  # <float> [batch_size * num_context, context_len, embed_dim]
  flat_ctx_emb = tf.reshape(ctx_emb, [-1, max_context_len, embed_dim])
  # <int32> [batch_size * num_context]
  flat_ctx_lens = tf.reshape(context_lens, [-1])

  # Replicate the question once per context, then flatten identically.
  # <float> [batch_size * num_context, question_len, embed_dim]
  tiled_q_emb = tf.tile(tf.expand_dims(q_emb, 1), [1, num_contexts, 1, 1])
  flat_q_emb = tf.reshape(tiled_q_emb, [-1, max_question_len, embed_dim])
  # <int32> [batch_size * num_context]
  tiled_q_lens = tf.tile(tf.expand_dims(question_lens, 1), [1, num_contexts])
  flat_q_lens = tf.reshape(tiled_q_lens, [-1])

  # <float> [batch_size * num_context, hidden_size]
  flat_decatt_emb = decatt.decomposable_attention(
      emb1=flat_q_emb,
      len1=flat_q_lens,
      emb2=flat_ctx_emb,
      len2=flat_ctx_lens,
      hidden_size=FLAGS.hidden_size,
      hidden_layers=FLAGS.hidden_layers,
      dropout_ratio=FLAGS.dropout_ratio,
      mode=mode)
  # Restore the per-context axis.
  # <float> [batch_size, num_context, hidden_size]
  decatt_emb = tf.reshape(
      flat_decatt_emb, [batch_size, num_contexts, FLAGS.hidden_size])

  # Hand-crafted (non-neural) features concatenated with the attention output.
  weighted_num_overlap, unweighted_num_overlap, pos_embs = (
      _get_non_neural_features(
          question_tok_wid=question_tok_wid,
          question_lens=question_lens,
          context_tok_wid=context_tok_wid,
          context_lens=context_lens))
  final_emb = tf.concat(
      [decatt_emb, weighted_num_overlap, unweighted_num_overlap, pos_embs], -1)

  # Final linear layer producing a single score per context.
  # <float> [batch_size, num_context]
  context_scores = tf.layers.dense(inputs=final_emb, units=1, activation=None)
  context_scores = tf.squeeze(context_scores, -1)
  return context_scores