Example 1
def build_model(etc_model_config: modeling.EtcConfig,
                features: Dict[str, tf.Tensor], flat_sequence: bool,
                is_training: bool, answer_encoding_method: str, use_tpu: bool,
                use_wordpiece: bool):
    """Build the ETC HotpotQA model."""
    long_token_ids = features["long_token_ids"]
    long_sentence_ids = features["long_sentence_ids"]
    long_paragraph_ids = features["long_paragraph_ids"]
    long_paragraph_breakpoints = features["long_paragraph_breakpoints"]
    long_token_type_ids = features["long_token_type_ids"]
    global_token_ids = features["global_token_ids"]
    global_paragraph_breakpoints = features["global_paragraph_breakpoints"]
    global_token_type_ids = features["global_token_type_ids"]

    model = modeling.EtcModel(config=etc_model_config,
                              is_training=is_training,
                              use_one_hot_relative_embeddings=use_tpu)

    model_inputs = dict(token_ids=long_token_ids,
                        global_token_ids=global_token_ids,
                        segment_ids=long_token_type_ids,
                        global_segment_ids=global_token_type_ids)

    cls_token_id = (generate_tf_examples_lib.
                    SENTENCEPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])
    if use_wordpiece:
        cls_token_id = (generate_tf_examples_lib.
                        WORDPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])

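    # Attention side inputs for the global-local transformer: attention masks
    # and relative position ids derived from the sentence/paragraph structure.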
    model_inputs.update(
        qa_input_utils.make_global_local_transformer_side_inputs(
            long_paragraph_breakpoints=long_paragraph_breakpoints,
            long_paragraph_ids=long_paragraph_ids,
            long_sentence_ids=long_sentence_ids,
            global_paragraph_breakpoints=global_paragraph_breakpoints,
            local_radius=etc_model_config.local_radius,
            relative_pos_max_distance=etc_model_config.relative_pos_max_distance,
            use_hard_g2l_mask=etc_model_config.use_hard_g2l_mask,
            ignore_hard_g2l_mask=tf.cast(tf.equal(global_token_ids,
                                                  cls_token_id),
                                         dtype=long_sentence_ids.dtype),
            flat_sequence=flat_sequence,
            use_hard_l2g_mask=etc_model_config.use_hard_l2g_mask).to_dict(
                exclude_none_values=True))

    long_output, global_output = model(**model_inputs)

    batch_size, long_seq_length, long_hidden_size = tensor_utils.get_shape_list(
        long_output, expected_rank=3)
    _, global_seq_length, global_hidden_size = tensor_utils.get_shape_list(
        global_output, expected_rank=3)

    long_output_matrix = tf.reshape(
        long_output, [batch_size * long_seq_length, long_hidden_size])
    global_output_matrix = tf.reshape(
        global_output, [batch_size * global_seq_length, global_hidden_size])

    # Get the logits for the supporting facts predictions.
    supporting_facts_output_weights = tf.get_variable(
        "supporting_facts_output_weights", [1, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    supporting_facts_output_bias = tf.get_variable(
        "supporting_facts_output_bias", [1],
        initializer=tf.zeros_initializer())
    supporting_facts_logits = tf.matmul(global_output_matrix,
                                        supporting_facts_output_weights,
                                        transpose_b=True)
    supporting_facts_logits = tf.nn.bias_add(supporting_facts_logits,
                                             supporting_facts_output_bias)
    supporting_facts_logits = tf.reshape(supporting_facts_logits,
                                         [batch_size, global_seq_length])

    # Get the logits for the answer type prediction.
    num_answer_types = 3  # SPAN, YES, NO
    answer_type_output_weights = tf.get_variable(
        "answer_type_output_weights", [num_answer_types, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_type_output_bias = tf.get_variable(
        "answer_type_output_bias", [num_answer_types],
        initializer=tf.zeros_initializer())
    answer_type_logits = tf.matmul(global_output[:, 0, :],
                                   answer_type_output_weights,
                                   transpose_b=True)
    answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                        answer_type_output_bias)

    extra_model_losses = model.losses

    if answer_encoding_method == "span":
        # Get the logits for the begin and end indices.
        answer_span_output_weights = tf.get_variable(
            "answer_span_output_weights", [2, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_span_output_bias = tf.get_variable(
            "answer_span_output_bias", [2], initializer=tf.zeros_initializer())
        answer_span_logits = tf.matmul(long_output_matrix,
                                       answer_span_output_weights,
                                       transpose_b=True)
        answer_span_logits = tf.nn.bias_add(answer_span_logits,
                                            answer_span_output_bias)
        answer_span_logits = tf.reshape(answer_span_logits,
                                        [batch_size, long_seq_length, 2])
        answer_span_logits = tf.transpose(answer_span_logits, [2, 0, 1])
        answer_begin_logits, answer_end_logits = tf.unstack(answer_span_logits,
                                                            axis=0)

        return (supporting_facts_logits, (answer_begin_logits,
                                          answer_end_logits),
                answer_type_logits, extra_model_losses)
    else:
        # Get the logits for the answer BIO encodings.
        answer_bio_output_weights = tf.get_variable(
            "answer_bio_output_weights", [3, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_bio_output_bias = tf.get_variable(
            "answer_bio_output_bias", [3], initializer=tf.zeros_initializer())
        answer_bio_logits = tf.matmul(long_output_matrix,
                                      answer_bio_output_weights,
                                      transpose_b=True)
        answer_bio_logits = tf.nn.bias_add(answer_bio_logits,
                                           answer_bio_output_bias)
        answer_bio_logits = tf.reshape(answer_bio_logits,
                                       [batch_size, long_seq_length, 3])

        return (supporting_facts_logits, answer_bio_logits, answer_type_logits,
                extra_model_losses)
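For context, here is a minimal sketch of how this build_model could be invoked from an Estimator model_fn. The feature dictionary, mode, and flag values below are assumptions for illustration, not part of the example above:

# Hypothetical call site (not from the repository); `features` is assumed to
# come from the HotpotQA input pipeline and to contain every key read above.
(supporting_facts_logits, answer_logits, answer_type_logits,
 extra_model_losses) = build_model(
     etc_model_config=etc_model_config,
     features=features,
     flat_sequence=False,
     is_training=(mode == tf.estimator.ModeKeys.TRAIN),
     answer_encoding_method="span",  # any other value selects the BIO branch
     use_tpu=True,
     use_wordpiece=False)
# With "span", answer_logits is the (answer_begin_logits, answer_end_logits)
# pair; otherwise it is the [batch, long_seq_length, 3] BIO logits tensor.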
Example 2
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        long_token_ids = features["long_token_ids"]
        long_token_type_ids = features["long_token_type_ids"]
        global_token_ids = features["global_token_ids"]
        global_token_type_ids = features["global_token_type_ids"]

        model_inputs = dict(token_ids=long_token_ids,
                            global_token_ids=global_token_ids,
                            global_segment_ids=global_token_type_ids,
                            segment_ids=long_token_type_ids)

        for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
            model_inputs[field.name] = features[field.name]

        labels = tf.cast(features["label_id"], dtype=tf.int32)

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(labels), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = modeling.EtcModel(
            config=model_config,
            is_training=is_training,
            use_one_hot_embeddings=use_one_hot_embeddings,
            use_one_hot_relative_embeddings=use_tpu)

        _, global_output = model(**model_inputs)
        (total_loss, per_example_loss, logits) = (process_model_output(
            model_config, mode, global_output, global_token_type_ids, labels,
            is_real_example, add_final_layer, label_smoothing))
        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = input_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu, optimizer, poly_power, start_warmup_step,
                learning_rate_schedule, weight_decay_rate)

            metrics_dict = metric_fn(per_example_loss=per_example_loss,
                                     logits=logits,
                                     labels=labels,
                                     is_real_example=is_real_example,
                                     is_train=True)

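            # Summary ops cannot run inside the TPU training loop, so training
            # metrics are sent back to the host via `host_call` and written as
            # summaries by `record_summary_host_fn`.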
            host_inputs = {
                "global_step":
                tf.expand_dims(tf.train.get_or_create_global_step(), 0),
            }

            host_inputs.update({
                metric_name: tf.expand_dims(metric_tensor, 0)
                for metric_name, metric_tensor in metrics_dict.items()
            })

            host_call = (functools.partial(record_summary_host_fn,
                                           metrics_dir=os.path.join(
                                               model_dir, "train_metrics"),
                                           steps_per_summary=50), host_inputs)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)

        elif mode == tf.estimator.ModeKeys.EVAL:
            metric_fn_tensors = dict(per_example_loss=per_example_loss,
                                     logits=logits,
                                     labels=labels,
                                     is_real_example=is_real_example)

            eval_metrics = (functools.partial(metric_fn, is_train=False),
                            metric_fn_tensors)
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={
                    "logits": logits,
                    # Wrap in `tf.identity` to avoid b/130501786.
                    "label_ids": tf.identity(labels),
                },
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Unexpected mode {} encountered".format(mode))

        return output_spec
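
For completeness, a minimal sketch of how a closure like this model_fn is typically handed to the TF1 TPUEstimator API. The run configuration and batch sizes below are assumptions for illustration:

# Hypothetical wiring (not from the repository).
run_config = tf.estimator.tpu.RunConfig(
    model_dir=model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))
estimator = tf.estimator.tpu.TPUEstimator(
    use_tpu=use_tpu,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=32,
    eval_batch_size=8,
    predict_batch_size=8)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)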
Example 3
def build_model(etc_model_config, features, is_training, flags):
    """Build an ETC model."""
    token_ids = features["token_ids"]
    global_token_ids = features["global_token_ids"]

    model = modeling.EtcModel(config=etc_model_config,
                              is_training=is_training,
                              use_one_hot_relative_embeddings=flags.use_tpu)

    model_inputs = dict(token_ids=token_ids, global_token_ids=global_token_ids)
    for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
        if field.name in features:
            model_inputs[field.name] = features[field.name]

    # Get the logits for the start and end predictions.
    l_final_hidden, _ = model(**model_inputs)

    l_final_hidden_shape = tensor_utils.get_shape_list(l_final_hidden,
                                                       expected_rank=3)

    batch_size = l_final_hidden_shape[0]
    l_seq_length = l_final_hidden_shape[1]
    hidden_size = l_final_hidden_shape[2]

    num_answer_types = 5  # NULL, YES, NO, LONG, SHORT

    # We add a dense layer to the long output:
    l_output_weights = tf.get_variable(
        "cls/nq/long_output_weights", [4, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    l_output_bias = tf.get_variable("cls/nq/long_output_bias", [4],
                                    initializer=tf.zeros_initializer())
    l_final_hidden_matrix = tf.reshape(
        l_final_hidden, [batch_size * l_seq_length, hidden_size])
    l_logits = tf.matmul(l_final_hidden_matrix,
                         l_output_weights,
                         transpose_b=True)
    l_logits = tf.nn.bias_add(l_logits, l_output_bias)
    l_logits = tf.reshape(l_logits, [batch_size, l_seq_length, 4])

    if flags.mask_long_output:
        # Mask out invalid SA/LA start/end positions:
        # 1) find the SEP and CLS tokens:
        long_sep = tf.cast(tf.equal(token_ids, flags.sep_tok_id), tf.int32)
        long_not_sep = 1 - long_sep
        long_cls = tf.cast(tf.equal(token_ids, flags.cls_tok_id), tf.int32)

        # 2) cumulative-sum the SEPs; the only valid answer positions are those
        #    where the sum equals 1 (excluding the SEP tokens themselves), plus
        #    the CLS position
        l_mask = tf.cast(tf.equal(tf.cumsum(long_sep, axis=-1), 1), tf.int32)
        l_mask = 1 - ((l_mask * long_not_sep) + long_cls)

        # 3) apply the mask to the logits
        l_mask = tf.expand_dims(tf.cast(l_mask, tf.float32) * -10E8, 2)
        l_logits = tf.math.add(l_logits, l_mask)

    # Get the logits for the answer type prediction.
    answer_type_output_layer = l_final_hidden[:, 0, :]
    answer_type_hidden_size = answer_type_output_layer.shape[-1].value

    answer_type_output_weights = tf.get_variable(
        "answer_type_output_weights",
        [num_answer_types, answer_type_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    answer_type_output_bias = tf.get_variable(
        "answer_type_output_bias", [num_answer_types],
        initializer=tf.zeros_initializer())

    answer_type_logits = tf.matmul(answer_type_output_layer,
                                   answer_type_output_weights,
                                   transpose_b=True)
    answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                        answer_type_output_bias)

    extra_model_losses = model.losses

    l_logits = tf.transpose(l_logits, [2, 0, 1])
    l_unstacked_logits = tf.unstack(l_logits, axis=0)
    return ([l_unstacked_logits[i]
             for i in range(4)], answer_type_logits, extra_model_losses)
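
The four unstacked [batch_size, seq_length] tensors returned above are per-token position logits for the answer spans. A common way to train such logits is softmax cross-entropy against one-hot position labels; the sketch below shows that pattern and is an assumption, not the repository's actual loss code:

def position_loss(position_logits, one_hot_positions):
    """Softmax cross-entropy between [batch, seq_len] logits and one-hot labels."""
    log_probs = tf.nn.log_softmax(position_logits, axis=-1)
    return -tf.reduce_mean(
        tf.reduce_sum(
            tf.cast(one_hot_positions, log_probs.dtype) * log_probs, axis=-1))
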
def _build_model(model_config, features, is_training, flags):
  """Build an ETC model for OpenKP."""

  global_embedding_adder = None
  long_embedding_adder = None

  # Create `global_embedding_adder` if using visual features.
  if flags.use_visual_features_in_global or flags.use_visual_features_in_long:
    global_embedding_adder = _create_global_visual_feature_embeddings(
        model_config, features, flags)

  if flags.use_visual_features_in_long:
    # Create `long_embedding_adder` based on `global_embedding_adder`
    long_embedding_adder = gather_global_embeddings_to_long(
        global_embedding_adder, features['long_vdom_idx'])

  if not flags.use_visual_features_in_global:
    global_embedding_adder = None

  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(
      token_ids=features['long_token_ids'],
      global_token_ids=features['global_token_ids'],
      long_embedding_adder=long_embedding_adder,
      global_embedding_adder=global_embedding_adder)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  long_output, _ = model(**model_inputs)

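  # Pool the long (wordpiece-level) outputs into one embedding per word by
  # summing each word's wordpiece vectors, then layer-normalize.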
  word_embeddings_unnormalized = batch_segment_sum_embeddings(
      long_embeddings=long_output,
      long_word_idx=features['long_word_idx'],
      long_input_mask=features['long_input_mask'])
  word_emb_layer_norm = tf.keras.layers.LayerNormalization(
      axis=-1, epsilon=1e-12, name='word_emb_layer_norm')
  word_embeddings = word_emb_layer_norm(word_embeddings_unnormalized)

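  # For each n-gram length n = 1..kp_max_length, a width-n convolution over the
  # word embeddings yields one logit per possible n-gram start position.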
  ngram_logit_list = []
  for i in range(flags.kp_max_length):
    conv = tf.keras.layers.Conv1D(
        filters=model_config.hidden_size,
        kernel_size=i + 1,
        padding='valid',
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02 / math.sqrt(i + 1)),
        name=f'{i + 1}gram_conv')
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name=f'{i + 1}gram_layer_norm')

    logit_dense = tf.keras.layers.Dense(
        units=1,
        activation=None,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name=f'logit_dense{i}')
    # [batch_size, long_max_length - i]
    unpadded_logits = tf.squeeze(
        logit_dense(layer_norm(conv(word_embeddings))), axis=-1)

    # Pad to the right to get back to `long_max_length`.
    padded_logits = tf.pad(unpadded_logits, paddings=[[0, 0], [0, i]])

    # Padding logits should be ignored, so we make a large negative mask adder
    # for them.
    shifted_word_mask = tf.cast(
        tensor_utils.shift_elements_right(
            features['long_word_input_mask'], axis=-1, amount=-i),
        dtype=padded_logits.dtype)
    mask_adder = -10000.0 * (1.0 - shifted_word_mask)

    ngram_logit_list.append(padded_logits * shifted_word_mask + mask_adder)

  # [batch_size, kp_max_length, long_max_length]
  ngram_logits = tf.stack(ngram_logit_list, axis=1)

  extra_model_losses = model.losses

  return ngram_logits, extra_model_losses
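
One way to decode ngram_logits into keyphrase candidates is to rank every (n-gram length, start word) pair by its logit and keep the top-scoring ones; the sketch below is an assumption about post-processing, not the repository's OpenKP inference code:

# Hypothetical decoding step (not from the repository).
batch_size = tf.shape(ngram_logits)[0]
long_max_length = tf.shape(ngram_logits)[2]
flat_logits = tf.reshape(ngram_logits, [batch_size, -1])
top_scores, top_flat_idx = tf.math.top_k(flat_logits, k=3)
ngram_lengths = top_flat_idx // long_max_length + 1  # 1-based n-gram length
start_word_idx = top_flat_idx % long_max_length  # index of the first word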