Example #1
def _add_side_input_features(
    features: Mapping[Text, tf.Tensor],
    model_config: modeling.EtcConfig,
    candidate_ignore_hard_g2l: bool = True,
    query_ignore_hard_g2l: bool = True,
    enable_l2g_linking: bool = True,
) -> Dict[Text, tf.Tensor]:
    """Replaces raw input features with derived ETC side inputs."""

    features = dict(features)
    global_token_type_ids = features["global_token_type_ids"]

    if candidate_ignore_hard_g2l:
        # Have all the candidate global tokens attend to everything in the
        # long input even when `use_hard_g2l_mask` is enabled.
        candidate_ignore_hard_g2l_mask = tf.cast(
            tf.equal(global_token_type_ids,
                     multihop_utils.CANDIDATE_GLOBAL_TOKEN_TYPE_ID),
            dtype=global_token_type_ids.dtype)
    else:
        candidate_ignore_hard_g2l_mask = tf.zeros_like(global_token_type_ids)

    if query_ignore_hard_g2l:
        # Similarly, have all the question global tokens attend to everything
        # in the long input even when `use_hard_g2l_mask` is enabled.
        query_ignore_hard_g2l_mask = tf.cast(
            tf.equal(global_token_type_ids,
                     multihop_utils.QUESTION_GLOBAL_TOKEN_TYPE_ID),
            dtype=global_token_type_ids.dtype)
    else:
        query_ignore_hard_g2l_mask = tf.zeros_like(global_token_type_ids)

    ignore_hard_g2l_mask = (query_ignore_hard_g2l_mask +
                            candidate_ignore_hard_g2l_mask)

    if enable_l2g_linking:
        l2g_linked_ids = features["l2g_linked_ids"]
    else:
        l2g_linked_ids = None

    side_inputs = (multihop_utils.make_global_local_transformer_side_inputs(
        long_paragraph_breakpoints=features["long_paragraph_breakpoints"],
        long_paragraph_ids=features["long_paragraph_ids"],
        long_sentence_ids=features["long_sentence_ids"],
        global_paragraph_breakpoints=features["global_paragraph_breakpoints"],
        local_radius=model_config.local_radius,
        relative_pos_max_distance=model_config.relative_pos_max_distance,
        use_hard_g2l_mask=model_config.use_hard_g2l_mask,
        ignore_hard_g2l_mask=ignore_hard_g2l_mask,
        use_hard_l2g_mask=model_config.use_hard_l2g_mask,
        l2g_linked_ids=l2g_linked_ids))

    features.update(side_inputs.to_dict())
    return features
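
For orientation, here is a minimal, hedged sketch of how a feature-mapping function like this could be applied in a `tf.data` pipeline. Only the feature keys and the `_add_side_input_features` signature come from the example above; `make_parsed_dataset` and `params` are hypothetical stand-ins, and the batch-then-map order reflects the fact that the side-input builder works on `[batch, length]` tensors (see the rank-2 test tensors in Example #2).

# A hedged sketch, not from the source: `make_parsed_dataset` and `params`
# are hypothetical; `model_config` is an `EtcConfig` as above.
def input_fn(params):
    dataset = make_parsed_dataset(params)  # yields the raw feature dicts
    dataset = dataset.batch(params["batch_size"], drop_remainder=True)
    # Derive the ETC attention side inputs once per batch; the builder
    # expects [batch, length] tensors, as in the test below.
    return dataset.map(
        lambda feats: _add_side_input_features(feats, model_config))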
Example #2

    def test_make_global_local_transformer_side_inputs(self, use_hard_g2l_mask,
                                                       use_hard_l2g_mask):
        tf.logging.set_verbosity(tf.logging.INFO)
        # Example input:
        # Q_ID is the token corresponding to the question. P* represents
        # paragraph tokens and S* represents sentence-level tokens.
        #
        # A total of 14 global tokens as follows:
        #   0    1    2    3    4    5  6  7  8   9   10 11 12     13
        # [CLS] Q_ID Q_ID Q_ID Q_ID  P1 T1 S1 S2  P2  T2 S3 S4  --padding--
        #
        # Long input:
        # T*, W* represent (Title, Words) WordPieces in the HotpotQA context.
        # For example, (T1, W1) correspond to (Title1, Title1Words) and belong
        # to the same sentence in the long input. Hence, there is only one
        # corresponding global token for both of them.
        # Q* represent the question WordPieces.
        #
        # S1, S2 are sentences (each with one WordPiece) of T1, and S3, S4 are
        # sentences (each with one WordPiece) of T2:
        # Q1 Q2 Q3 Q4 T1 W1 S1 S2 T2 W2 S3 S4
        #
        # A total of 15 long tokens (the last 3 are padding).
        #
        # Padding with 0s
        long_paragraph_breakpoints = tf.convert_to_tensor(
            [[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]])

        # Note the difference here - padding with -1s instead.
        long_sentence_ids = tf.convert_to_tensor(
            [[1, 2, 3, 4, 6, 6, 7, 8, 10, 10, 11, 12, -1, -1, -1]])

        # Note the difference here - padding with -1s instead.
        long_paragraph_ids = tf.convert_to_tensor(
            [[-1, -1, -1, -1, 5, 5, 5, 5, 9, 9, 9, 9, -1, -1, -1]])

        # Let's say we want to link the 0th, 4th, and 5th long tokens to the
        # 1st global token, and the 6th long token to the 2nd global token.
        l2g_linked_ids = tf.convert_to_tensor(
            [[1, -1, -1, -1, 1, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1]])

        # Padding with 0s
        global_paragraph_breakpoints = tf.convert_to_tensor(
            [[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]])

        # Let's say we want the first and the third global tokens to attend to
        # everything in the long (except padding) even if `hard_g2l` is enabled.
        # Note that this tensor will be used / applicable only when
        # `use_hard_g2l_mask` is enabled.
        ignore_hard_g2l_mask = tf.convert_to_tensor(
            [[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

        # Let's say we want the first long token to attend to everything in the
        # global (except padding) even if `hard_l2g` is enabled. Note that this
        # tensor will be used / applicable only when `use_hard_l2g_mask` is
        # enabled.
        ignore_hard_l2g_mask = tf.convert_to_tensor(
            [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

        side_inputs = (
            multihop_utils.make_global_local_transformer_side_inputs(
                long_paragraph_breakpoints=long_paragraph_breakpoints,
                long_paragraph_ids=long_paragraph_ids,
                long_sentence_ids=long_sentence_ids,
                global_paragraph_breakpoints=global_paragraph_breakpoints,
                local_radius=4,
                relative_pos_max_distance=2,
                use_hard_g2l_mask=use_hard_g2l_mask,
                ignore_hard_g2l_mask=ignore_hard_g2l_mask,
                use_hard_l2g_mask=use_hard_l2g_mask,
                ignore_hard_l2g_mask=ignore_hard_l2g_mask,
                l2g_linked_ids=l2g_linked_ids))

        self._compare_side_inputs(side_inputs,
                                  use_hard_g2l_mask=use_hard_g2l_mask,
                                  use_hard_l2g_mask=use_hard_l2g_mask)
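
As a rough sanity check, the side inputs produced above are per-batch attention masks and relative-attention ids. Below is a hedged shape sketch, assuming the field names implied by the `to_dict()` call in Example #1; the exact field set depends on the ETC version in use.

        # Hedged shape sketch for batch=1, long_len=15, global_len=14,
        # local_radius=4. Field names are assumptions about this ETC version.
        side_dict = side_inputs.to_dict()
        # Local self-attention is banded: one slot per position in a
        # [-local_radius, +local_radius] window, i.e. 2*4 + 1 = 9.
        print(side_dict["l2l_att_mask"].shape)  # expected: (1, 15, 9)
        print(side_dict["g2g_att_mask"].shape)  # expected: (1, 14, 14)
        print(side_dict["l2g_att_mask"].shape)  # expected: (1, 15, 14)
        print(side_dict["g2l_att_mask"].shape)  # expected: (1, 14, 15)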
Example #3
def build_model(etc_model_config: modeling.EtcConfig,
                features: Dict[str, tf.Tensor], flat_sequence: bool,
                is_training: bool, answer_encoding_method: str, use_tpu: bool,
                use_wordpiece: bool):
    """Build the ETC HotpotQA model."""
    long_token_ids = features["long_token_ids"]
    long_sentence_ids = features["long_sentence_ids"]
    long_paragraph_ids = features["long_paragraph_ids"]
    long_paragraph_breakpoints = features["long_paragraph_breakpoints"]
    long_token_type_ids = features["long_token_type_ids"]
    global_token_ids = features["global_token_ids"]
    global_paragraph_breakpoints = features["global_paragraph_breakpoints"]
    global_token_type_ids = features["global_token_type_ids"]

    model = modeling.EtcModel(config=etc_model_config,
                              is_training=is_training,
                              use_one_hot_relative_embeddings=use_tpu)

    model_inputs = dict(token_ids=long_token_ids,
                        global_token_ids=global_token_ids,
                        segment_ids=long_token_type_ids,
                        global_segment_ids=global_token_type_ids)

    cls_token_id = (generate_tf_examples_lib.
                    SENTENCEPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])
    if use_wordpiece:
        cls_token_id = (generate_tf_examples_lib.
                        WORDPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])

    model_inputs.update(
        qa_input_utils.make_global_local_transformer_side_inputs(
            long_paragraph_breakpoints=long_paragraph_breakpoints,
            long_paragraph_ids=long_paragraph_ids,
            long_sentence_ids=long_sentence_ids,
            global_paragraph_breakpoints=global_paragraph_breakpoints,
            local_radius=etc_model_config.local_radius,
            relative_pos_max_distance=etc_model_config.
            relative_pos_max_distance,
            use_hard_g2l_mask=etc_model_config.use_hard_g2l_mask,
            ignore_hard_g2l_mask=tf.cast(tf.equal(global_token_ids,
                                                  cls_token_id),
                                         dtype=long_sentence_ids.dtype),
            flat_sequence=flat_sequence,
            use_hard_l2g_mask=etc_model_config.use_hard_l2g_mask).to_dict(
                exclude_none_values=True))

    long_output, global_output = model(**model_inputs)

    batch_size, long_seq_length, long_hidden_size = tensor_utils.get_shape_list(
        long_output, expected_rank=3)
    _, global_seq_length, global_hidden_size = tensor_utils.get_shape_list(
        global_output, expected_rank=3)

    long_output_matrix = tf.reshape(
        long_output, [batch_size * long_seq_length, long_hidden_size])
    global_output_matrix = tf.reshape(
        global_output, [batch_size * global_seq_length, global_hidden_size])

    # Get the logits for the supporting facts predictions.
    supporting_facts_output_weights = tf.get_variable(
        "supporting_facts_output_weights", [1, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    supporting_facts_output_bias = tf.get_variable(
        "supporting_facts_output_bias", [1],
        initializer=tf.zeros_initializer())
    supporting_facts_logits = tf.matmul(global_output_matrix,
                                        supporting_facts_output_weights,
                                        transpose_b=True)
    supporting_facts_logits = tf.nn.bias_add(supporting_facts_logits,
                                             supporting_facts_output_bias)
    supporting_facts_logits = tf.reshape(supporting_facts_logits,
                                         [batch_size, global_seq_length])

    # Get the logits for the answer type prediction.
    num_answer_types = 3  # SPAN, YES, NO
    answer_type_output_weights = tf.get_variable(
        "answer_type_output_weights", [num_answer_types, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_type_output_bias = tf.get_variable(
        "answer_type_output_bias", [num_answer_types],
        initializer=tf.zeros_initializer())
    answer_type_logits = tf.matmul(global_output[:, 0, :],
                                   answer_type_output_weights,
                                   transpose_b=True)
    answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                        answer_type_output_bias)

    extra_model_losses = model.losses

    if answer_encoding_method == "span":
        # Get the logits for the begin and end indices.
        answer_span_output_weights = tf.get_variable(
            "answer_span_output_weights", [2, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_span_output_bias = tf.get_variable(
            "answer_span_output_bias", [2], initializer=tf.zeros_initializer())
        answer_span_logits = tf.matmul(long_output_matrix,
                                       answer_span_output_weights,
                                       transpose_b=True)
        answer_span_logits = tf.nn.bias_add(answer_span_logits,
                                            answer_span_output_bias)
        answer_span_logits = tf.reshape(answer_span_logits,
                                        [batch_size, long_seq_length, 2])
        answer_span_logits = tf.transpose(answer_span_logits, [2, 0, 1])
        answer_begin_logits, answer_end_logits = tf.unstack(answer_span_logits,
                                                            axis=0)

        return (supporting_facts_logits, (answer_begin_logits,
                                          answer_end_logits),
                answer_type_logits, extra_model_losses)
    else:
        # Get the logits for the answer BIO encodings.
        answer_bio_output_weights = tf.get_variable(
            "answer_bio_output_weights", [3, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_bio_output_bias = tf.get_variable(
            "answer_bio_output_bias", [3], initializer=tf.zeros_initializer())
        answer_bio_logits = tf.matmul(long_output_matrix,
                                      answer_bio_output_weights,
                                      transpose_b=True)
        answer_bio_logits = tf.nn.bias_add(answer_bio_logits,
                                           answer_bio_output_bias)
        answer_bio_logits = tf.reshape(answer_bio_logits,
                                       [batch_size, long_seq_length, 3])

        return (supporting_facts_logits, answer_bio_logits, answer_type_logits,
                extra_model_losses)
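
To show how the returned logits are typically consumed, here is a hedged training-loss sketch for the "span" encoding, not taken from the source. The label tensors (`supporting_facts_labels`, `answer_begin_labels`, `answer_end_labels`, `answer_type_labels`) are assumed names and shapes.

# A hedged sketch of consuming build_model's outputs (span encoding).
# Label tensors are hypothetical: supporting_facts_labels is
# [batch, global_seq_length] in {0, 1}; the others are [batch] int indices.
(sp_logits, (begin_logits, end_logits), type_logits,
 extra_losses) = build_model(etc_model_config, features, flat_sequence=False,
                             is_training=True, answer_encoding_method="span",
                             use_tpu=False, use_wordpiece=False)

# Per-global-token sigmoid loss for supporting-facts prediction.
sp_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(supporting_facts_labels, tf.float32),
        logits=sp_logits))
# Softmax losses over long positions for the answer span boundaries.
begin_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=answer_begin_labels, logits=begin_logits))
end_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=answer_end_labels, logits=end_logits))
# Softmax loss over the 3 answer types (SPAN, YES, NO).
type_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=answer_type_labels, logits=type_logits))

total_loss = sp_loss + begin_loss + end_loss + type_loss
if extra_losses:
    total_loss += tf.add_n(extra_losses)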