def test_decomposable_attention(self):
  with tf.Graph().as_default():
    input1_emb = tf.random_uniform([3, 5, 7])
    input1_len = tf.constant([5, 2, 0])
    input2_emb = tf.random_uniform([3, 8, 7])
    input2_len = tf.constant([8, 6, 1])
    output_emb = decatt.decomposable_attention(
        emb1=input1_emb,
        len1=input1_len,
        emb2=input2_emb,
        len2=input2_len,
        hidden_size=5,
        hidden_layers=2,
        dropout_ratio=0.1,
        mode=tf.estimator.ModeKeys.TRAIN)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_output_emb = sess.run(output_emb)
    self.assertAllEqual(actual_output_emb.shape, [3, 5])
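The method above is written in the style of a tf.test.TestCase. A minimal harness sketch for running it is shown below; the decatt import, the class name, and the TF1-style tensorflow.compat.v1 alias are assumptions, not part of this listing.

import tensorflow.compat.v1 as tf  # assumed TF1-style API

import decatt  # assumed module exposing decomposable_attention


class DecattTest(tf.test.TestCase):

  def test_decomposable_attention(self):
    ...  # body as shown above


if __name__ == "__main__":
  tf.test.main()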
Example #2
def build_model(question_tok_wid, question_lens, context_tok_wid, context_lens,
                embedding_weights, mode):
    """Wrapper around for Decomposable Attention model for NQ long answer scoring.

  Args:
    question_tok_wid: <int32> [batch_size, question_len]
    question_lens: <int32> [batch_size]
    context_tok_wid: <int32> [batch_size, num_context, context_len]
    context_lens: <int32> [batch_size, num_context]
    embedding_weights: <float> [vocab_size, embed_dim]
    mode: One of the keys from tf.estimator.ModeKeys.

  Returns:
    context_scores: <float> [batch_size, num_context]
  """
    # <float> [batch_size, question_len, embed_dim]
    question_emb = tf.nn.embedding_lookup(embedding_weights, question_tok_wid)
    # <float> [batch_size, num_context, context_len, embed_dim]
    context_emb = tf.nn.embedding_lookup(embedding_weights, context_tok_wid)

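    # Project both embeddings down to hidden_size with a single shared dense
    # layer: the "reduce_emb" variables are created here for the question and
    # reused (reuse=True) for the contexts below.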
    question_emb = tf.layers.dense(inputs=question_emb,
                                   units=FLAGS.hidden_size,
                                   activation=None,
                                   name="reduce_emb",
                                   reuse=False)

    context_emb = tf.layers.dense(inputs=context_emb,
                                  units=FLAGS.hidden_size,
                                  activation=None,
                                  name="reduce_emb",
                                  reuse=True)

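    # Recover the (possibly dynamic) dimensions so the contexts can be folded
    # into the batch dimension and the question tiled once per context.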
    batch_size, num_contexts, max_context_len, embed_dim = (
        tensor_utils.shape(context_emb))
    _, max_question_len, _ = tensor_utils.shape(question_emb)

    # <float> [batch_size * num_context, context_len, embed_dim]
    flat_context_emb = tf.reshape(context_emb,
                                  [-1, max_context_len, embed_dim])

    # <int32> [batch_size * num_context]
    flat_context_lens = tf.reshape(context_lens, [-1])

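    # Tile the question once per context so that, after flattening, each row
    # pairs one question with one of its candidate contexts.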
    # <float> [batch_size * num_context, question_len, embed_dim]
    question_emb_tiled = tf.tile(tf.expand_dims(question_emb, 1),
                                 [1, num_contexts, 1, 1])
    flat_question_emb_tiled = tf.reshape(question_emb_tiled,
                                         [-1, max_question_len, embed_dim])

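    # Tile the question lengths the same way to keep them aligned with the
    # flattened question embeddings.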
    # <int32> [batch_size * num_context]
    question_lens_tiled = tf.tile(tf.expand_dims(question_lens, 1),
                                  [1, num_contexts])
    flat_question_lens_tiled = tf.reshape(question_lens_tiled, [-1])

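    # Apply decomposable attention to each flattened (question, context) pair,
    # yielding one pooled vector per pair.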
    # <float> [batch_size * num_context, hidden_size]
    flat_decatt_emb = decatt.decomposable_attention(
        emb1=flat_question_emb_tiled,
        len1=flat_question_lens_tiled,
        emb2=flat_context_emb,
        len2=flat_context_lens,
        hidden_size=FLAGS.hidden_size,
        hidden_layers=FLAGS.hidden_layers,
        dropout_ratio=FLAGS.dropout_ratio,
        mode=mode)

    # <float> [batch_size, num_context, hidden_size]
    decatt_emb = tf.reshape(flat_decatt_emb,
                            [batch_size, num_contexts, FLAGS.hidden_size])

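    # Non-neural features: weighted and unweighted token-overlap counts plus
    # position embeddings for each context.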
    weighted_num_overlap, unweighted_num_overlap, pos_embs = (
        _get_non_neural_features(question_tok_wid=question_tok_wid,
                                 question_lens=question_lens,
                                 context_tok_wid=context_tok_wid,
                                 context_lens=context_lens))

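    # Concatenate the attention-based representation with the non-neural
    # features along the last dimension.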
    final_emb = tf.concat(
        [decatt_emb, weighted_num_overlap, unweighted_num_overlap, pos_embs],
        -1)

    # Final linear layer to get score.
    # <float> [batch_size, num_context]
    context_scores = tf.layers.dense(inputs=final_emb,
                                     units=1,
                                     activation=None)
    context_scores = tf.squeeze(context_scores, -1)

    return context_scores
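
A hedged sketch of how build_model might be wired into an estimator-style model_fn follows. The feature keys, the label format (index of the gold context per example), the softmax cross-entropy loss, and the Adam optimizer are all assumptions rather than part of the original listing.

def model_fn(features, labels, mode, params):
    """Assumed estimator wiring for build_model; not taken from the source."""
    # <float> [batch_size, num_context]
    context_scores = build_model(
        question_tok_wid=features["question_tok_wid"],
        question_lens=features["question_lens"],
        context_tok_wid=features["context_tok_wid"],
        context_lens=features["context_lens"],
        embedding_weights=params["embedding_weights"],
        mode=mode)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions={"context_scores": context_scores})

    # Assumed labels: <int32> [batch_size], index of the gold context.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=context_scores))
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)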