def test_reader_inputs(self):
        """Checks the concatenated [CLS] question [SEP] block [SEP] layout."""
        # Two toy blocks: the first has 3 real tokens, the second only 2
        # (its trailing 0 is padding, per block_lengths). Of the two answers,
        # only [3, 4] appears in a block (block 0).
        question = [0, 1]
        blocks = [[2, 3, 4],
                  [5, 6, 0]]
        lengths = [3, 2]
        token_map = [[1, 2, 5],
                     [1, 3, 4]]
        answers = [[3, 4],
                   [7, 0]]
        answer_lens = [2, 1]

        outputs = orqa_ops.reader_inputs(question_token_ids=question,
                                         block_token_ids=blocks,
                                         block_lengths=lengths,
                                         block_token_map=token_map,
                                         answer_token_ids=answers,
                                         answer_lengths=answer_lens,
                                         cls_token_id=10,
                                         sep_token_id=11,
                                         max_sequence_len=10)

        # Row layout: [CLS]=10, question, [SEP]=11, block, [SEP]=11, padding.
        want_token_ids = [[10, 0, 1, 11, 2, 3, 4, 11, 0, 0],
                          [10, 0, 1, 11, 5, 6, 11, 0, 0, 0]]
        # 1 on every non-padding position.
        want_mask = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]
        # 0 for [CLS] question [SEP], 1 for the block and its closing [SEP].
        want_segment_ids = [[0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
                            [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]]
        # 1 only on the block's own tokens.
        want_block_mask = [[0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                           [0, 0, 0, 0, 1, 1, 0, 0, 0, 0]]
        # Block tokens carry their block_token_map entry; everything else -1.
        want_token_map = [[-1, -1, -1, -1, 1, 2, 5, -1, -1, -1],
                          [-1, -1, -1, -1, 1, 3, -1, -1, -1, -1]]
        # Answer [3, 4] sits at concatenated positions 5..6 of block 0;
        # block 1 contains no answer, hence -1.
        want_gold_starts = [[5], [-1]]
        want_gold_ends = [[6], [-1]]

        self.assertAllEqual(outputs.token_ids.numpy(), want_token_ids)
        self.assertAllEqual(outputs.mask.numpy(), want_mask)
        self.assertAllEqual(outputs.segment_ids.numpy(), want_segment_ids)
        self.assertAllEqual(outputs.block_mask.numpy(), want_block_mask)
        self.assertAllEqual(outputs.token_map.numpy(), want_token_map)
        self.assertAllEqual(outputs.gold_starts.numpy(), want_gold_starts)
        self.assertAllEqual(outputs.gold_ends.numpy(), want_gold_ends)
Esempio n. 2
0
def read(features, retriever_logits, blocks, mode, params, labels):
    """Score answer spans inside the retrieved blocks with a BERT reader.

    The question is concatenated with every retrieved block by
    `orqa_ops.reader_inputs` and encoded with the TF-Hub module at
    `params["reader_module_path"]`. Each candidate span is then scored by a
    small MLP over its start/end token projections:
    score(s, e) = w . LayerNorm(relu(w_s . h_s + w_e . h_e)).
    Candidate spans come from `span_candidates` applied to the block mask,
    with a maximum width of `params["max_span_width"]`. Each block's retriever
    logit is added to all span logits of that block.

    Args:
      features: Dict of input features. `features["question"]` holds the
        question text. It is expanded to a batch of one before tokenization,
        so it is presumably a scalar string tensor; verify against the caller.
      retriever_logits: Per-block retriever scores. They are expanded on the
        last axis and added to the span logits, so they are presumably
        shaped [reader_beam_size].
      blocks: Raw text of the retrieved blocks, tokenized with
        `bert_utils.tokenize_with_original_mapping`.
      mode: A `tf.estimator.ModeKeys` value. The reader module is loaded with
        the "train" tag set only in TRAIN mode.
      params: Hyperparameter dict. Keys read here are "reader_module_path",
        "reader_seq_len", "max_span_width" and "span_hidden_size".
      labels: Answer strings. They are tokenized and passed to
        `reader_inputs` to locate gold spans.

    Returns:
      A `ReaderOutputs` containing:
        - span logits;
        - candidate start/end positions, both in the concatenated sequence
          and mapped through `token_map`;
        - the tokenized and original blocks and the original tokens;
        - the concatenated token ids;
        - the gold start/end positions.
    """
    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        params["reader_module_path"])

    # Keep the raw blocks: `blocks` is rebound by the tokenizer helper below.
    orig_blocks = blocks

    (orig_tokens, block_token_map, block_token_ids,
     blocks) = bert_utils.tokenize_with_original_mapping(blocks, tokenizer)

    # Tokenize the question as a batch of one. Then flatten every ragged
    # level into a single id sequence.
    question_token_ids = tokenizer.tokenize(
        tf.expand_dims(features["question"], 0))
    question_token_ids = tf.cast(question_token_ids.flat_values, tf.int32)

    # Densify the ragged per-block tensors. Record the true block lengths
    # before padding.
    orig_tokens = orig_tokens.to_tensor(default_value="")
    block_lengths = tf.cast(block_token_ids.row_lengths(), tf.int32)
    block_token_ids = tf.cast(block_token_ids.to_tensor(), tf.int32)
    block_token_map = tf.cast(block_token_map.to_tensor(), tf.int32)

    # Merge ragged axes 1 and 2 so each answer becomes one id sequence
    # (presumably [answer, word, wordpiece] -> [answer, wordpiece]).
    answer_token_ids = tokenizer.tokenize(labels).merge_dims(1, 2)
    answer_lengths = tf.cast(answer_token_ids.row_lengths(), tf.int32)
    answer_token_ids = tf.cast(answer_token_ids.to_tensor(), tf.int32)

    cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
    sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
    concat_inputs = orqa_ops.reader_inputs(
        question_token_ids=question_token_ids,
        block_token_ids=block_token_ids,
        block_lengths=block_lengths,
        block_token_map=block_token_map,
        answer_token_ids=answer_token_ids,
        answer_lengths=answer_lengths,
        cls_token_id=tf.cast(cls_token_id, tf.int32),
        sep_token_id=tf.cast(sep_token_id, tf.int32),
        max_sequence_len=params["reader_seq_len"])

    # Fraction of non-padding positions in the reader input.
    tf.summary.scalar("reader_nonpad_ratio",
                      tf.reduce_mean(tf.cast(concat_inputs.mask, tf.float32)))

    # The empty tag set must be `set()`: a bare `{}` is an empty dict literal.
    reader_module = hub.Module(
        params["reader_module_path"],
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else set(),
        trainable=True)

    concat_outputs = reader_module(dict(input_ids=concat_inputs.token_ids,
                                        input_mask=concat_inputs.mask,
                                        segment_ids=concat_inputs.segment_ids),
                                   signature="tokens",
                                   as_dict=True)

    concat_token_emb = concat_outputs["sequence_output"]

    # [num_spans], [num_spans], [reader_beam_size, num_spans]
    candidate_starts, candidate_ends, candidate_mask = span_candidates(
        concat_inputs.block_mask, params["max_span_width"])

    # Score with an MLP to enable start/end interaction:
    # score(s, e) = w·σ(w_s·h_s + w_e·h_e)
    # NOTE: layer creation order below fixes the variable names; keep it
    # stable for checkpoint compatibility.
    kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

    # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
    projection = tf.layers.dense(concat_token_emb,
                                 params["span_hidden_size"] * 2,
                                 kernel_initializer=kernel_initializer)

    # [reader_beam_size, max_sequence_len, span_hidden_size]
    start_projection, end_projection = tf.split(projection, 2, -1)

    # [reader_beam_size, num_candidates, span_hidden_size]
    candidate_start_projections = tf.gather(start_projection,
                                            candidate_starts,
                                            axis=1)
    candidate_end_projection = tf.gather(end_projection,
                                         candidate_ends,
                                         axis=1)
    candidate_hidden = candidate_start_projections + candidate_end_projection

    candidate_hidden = tf.nn.relu(candidate_hidden)
    candidate_hidden = tf.keras.layers.LayerNormalization(
        axis=-1)(candidate_hidden)

    # [reader_beam_size, num_candidates, 1]
    reader_logits = tf.layers.dense(candidate_hidden,
                                    1,
                                    kernel_initializer=kernel_initializer)

    # [reader_beam_size, num_candidates]
    # Squeeze only the trailing unit axis. An axis-less squeeze would also
    # drop reader_beam_size or num_candidates when either is 1, and would
    # lose static rank when those dims are dynamic.
    reader_logits = tf.squeeze(reader_logits, axis=-1)
    reader_logits += mask_to_score(candidate_mask)
    reader_logits += tf.expand_dims(retriever_logits, -1)

    # [reader_beam_size, num_candidates]
    candidate_orig_starts = tf.gather(params=concat_inputs.token_map,
                                      indices=candidate_starts,
                                      axis=-1)
    candidate_orig_ends = tf.gather(params=concat_inputs.token_map,
                                    indices=candidate_ends,
                                    axis=-1)

    return ReaderOutputs(logits=reader_logits,
                         candidate_starts=candidate_starts,
                         candidate_ends=candidate_ends,
                         candidate_orig_starts=candidate_orig_starts,
                         candidate_orig_ends=candidate_orig_ends,
                         blocks=blocks,
                         orig_blocks=orig_blocks,
                         orig_tokens=orig_tokens,
                         token_ids=concat_inputs.token_ids,
                         gold_starts=concat_inputs.gold_starts,
                         gold_ends=concat_inputs.gold_ends)
Esempio n. 3
0
def read(features, retriever_logits, blocks, mode, params):
    """Classify the question against each retrieved block with a BERT reader.

    The question is concatenated with every retrieved block by
    `orqa_ops.reader_inputs` and encoded with the TF-Hub module at
    `params["reader_module_path"]`. A dense layer maps the module's
    "pooled_output" to `params["num_classes"]` logits per block. Each block's
    retriever logit is then added to all class logits of that block.

    Args:
      features: Dict of input features. `features["question"]` holds the
        question text. It is expanded to a batch of one before tokenization,
        so it is presumably a scalar string tensor; verify against the caller.
      retriever_logits: Per-block retriever scores. They are expanded on the
        last axis and added to the class logits, so they are presumably
        shaped [reader_beam_size].
      blocks: Raw text of the retrieved blocks, tokenized with
        `bert_utils.tokenize_with_original_mapping`.
      mode: A `tf.estimator.ModeKeys` value. The reader module is loaded with
        the "train" tag set only in TRAIN mode.
      params: Hyperparameter dict. Keys read here are "reader_module_path",
        "reader_seq_len" and "num_classes".

    Returns:
      Logits of shape [reader_beam_size, num_classes]: reader class logits
      plus the broadcast retriever logits.
    """
    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        params["reader_module_path"])

    # Only the token ids and the token map are needed here. The original
    # tokens and the tokenized blocks are discarded.
    (_, block_token_map, block_token_ids,
     _) = bert_utils.tokenize_with_original_mapping(blocks, tokenizer)

    # NOTE: we assume that the batch size is 1.
    question_token_ids = tokenizer.tokenize(
        tf.expand_dims(features["question"], 0))
    question_token_ids = tf.cast(question_token_ids.flat_values, tf.int32)

    # Densify the ragged per-block tensors. Record the true block lengths
    # before padding.
    block_lengths = tf.cast(block_token_ids.row_lengths(), tf.int32)
    block_token_ids = tf.cast(block_token_ids.to_tensor(), tf.int32)
    block_token_map = tf.cast(block_token_map.to_tensor(), tf.int32)

    # `reader_inputs` requires answer inputs, but there are no labels here.
    # Feed a single empty answer; the resulting gold positions are unused.
    fake_answer = tf.constant([""])
    answer_token_ids = tokenizer.tokenize(fake_answer).merge_dims(1, 2)
    answer_lengths = tf.cast(answer_token_ids.row_lengths(), tf.int32)
    answer_token_ids = tf.cast(answer_token_ids.to_tensor(), tf.int32)

    cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
    sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))

    concat_inputs = orqa_ops.reader_inputs(
        question_token_ids=question_token_ids,
        block_token_ids=block_token_ids,
        block_lengths=block_lengths,
        block_token_map=block_token_map,
        answer_token_ids=answer_token_ids,
        answer_lengths=answer_lengths,
        cls_token_id=tf.cast(cls_token_id, tf.int32),
        sep_token_id=tf.cast(sep_token_id, tf.int32),
        max_sequence_len=params["reader_seq_len"])

    # Fraction of non-padding positions in the reader input.
    tf.summary.scalar("reader_nonpad_ratio",
                      tf.reduce_mean(tf.cast(concat_inputs.mask, tf.float32)))

    # The empty tag set must be `set()`: a bare `{}` is an empty dict literal.
    reader_module = hub.Module(
        params["reader_module_path"],
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else set(),
        trainable=True)

    concat_outputs = reader_module(dict(input_ids=concat_inputs.token_ids,
                                        input_mask=concat_inputs.mask,
                                        segment_ids=concat_inputs.segment_ids),
                                   signature="tokens",
                                   as_dict=True)

    # Keep the "pooler" scope name: it determines the classifier's variable
    # names in checkpoints.
    with tf.variable_scope("pooler"):
        # [reader_beam_size, hidden_size]
        first_token_tensor = concat_outputs["pooled_output"]

        kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

        # [reader_beam_size, num_classes]
        reader_logits = tf.layers.dense(first_token_tensor,
                                        params["num_classes"],
                                        kernel_initializer=kernel_initializer)

    # Broadcast each block's retriever score over its class logits.
    final_logits = reader_logits + tf.expand_dims(retriever_logits, -1)

    return final_logits