Example #1
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("    name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        if "next_sentence_labels" in features:
            next_sentence_labels = features["next_sentence_labels"]
        else:
            next_sentence_labels = get_dummy_next_sentence_labels(input_ids)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # Fix the seed so that masking is deterministic at prediction time.
            tf.random.set_seed(0)
            seed = 0
            tf_logging.info("Seed fixed to zero")
        else:
            seed = None

        tf_logging.info("Doing dynamic masking (random)")
        special_tokens = [LABEL_UNK, LABEL_0, LABEL_1, LABEL_2]
        masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
            = random_masking(input_ids,
                             input_mask,
                             train_config.max_predictions_per_seq,
                             MASK_ID,
                             seed,
                             special_tokens)

        masked_input_ids, masked_lm_positions_label, masked_label_ids_label, is_test_inst \
            = get_label_indices(masked_input_ids)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = model_class(
            config=model_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output_fn(
             model_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        with tf.compat.v1.variable_scope("label_token"):
            (masked_lm_loss_label, masked_lm_example_loss_label,
             masked_lm_log_probs_label) = get_masked_lm_output_fn(
                 model_config, model.get_sequence_output(),
                 model.get_embedding_table(), masked_lm_positions_label,
                 masked_label_ids_label, is_test_inst)
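        # The NSP head below is computed for interface compatibility; its loss
        # is not added to total_loss.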
        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             model_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + masked_lm_loss_label * model_config.ratio

        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names, initialized_variable_names2, init_fn\
            = align_checkpoint_for_lm(tvars,
                                      train_config.checkpoint_type,
                                      train_config.init_checkpoint,
                                      train_config.second_init_checkpoint,
                                      )

        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

        log_var_assignments(tvars, initialized_variable_names,
                            initialized_variable_names2)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                total_loss, train_config)
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[OomReportingHook()],
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, masked_lm_example_loss_label,
                masked_lm_log_probs_label, masked_label_ids_label, is_test_inst
            ])
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "masked_lm_ids": masked_lm_ids,
                "masked_lm_example_loss": masked_lm_example_loss,
                "masked_lm_positions": masked_lm_positions,
                "masked_lm_example_loss_label": masked_lm_example_loss_label,
                "masked_lm_log_probs_label": masked_lm_log_probs_label,
                "masked_label_ids_label": masked_label_ids_label,
                "is_test_inst": is_test_inst,
            }
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                predictions=predictions,
                scaffold_fn=scaffold_fn)

        return output_spec
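For context: a `model_fn` closure like the one above is normally produced by a builder and handed to `TPUEstimator`. A minimal wiring sketch, assuming hypothetical paths and batch sizes (only `model_fn` comes from the code above):

    # Sketch: handing the model_fn above to TPUEstimator (illustrative values).
    import tensorflow as tf

    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        model_dir="/tmp/model_dir",  # hypothetical output directory
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=1000,
            num_shards=8))

    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=False,       # falls back to CPU/GPU when False
        model_fn=model_fn,   # the closure defined above
        config=run_config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=32)

    # estimator.train(input_fn=train_input_fn, max_steps=100000)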
Example #2
    def model_fn(features, labels, mode, params):    # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        logging.info("*** Features ***")
        for name in sorted(features.keys()):
            logging.info("    name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        d_input_ids = features["d_input_ids"]
        d_input_mask = features["d_input_mask"]
        d_location_ids = features["d_location_ids"]
        next_sentence_labels = features["next_sentence_labels"]

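        # A fixed seed makes the random masking deterministic, so per-example
        # losses are reproducible when exporting them at prediction time.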
        if dict_run_config.prediction_op == "loss":
            seed = 0
        else:
            seed = None

        if dict_run_config.prediction_op == "loss_fixed_mask" or train_config.fixed_mask:
            masked_input_ids = input_ids
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = tf.ones_like(masked_lm_positions, dtype=tf.float32)
        else:
            masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
                = random_masking(input_ids, input_mask,
                                 train_config.max_predictions_per_seq, MASK_ID, seed)

        if dict_run_config.use_d_segment_ids:
            d_segment_ids = features["d_segment_ids"]
        else:
            d_segment_ids = None

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = model_class(
                config=bert_config,
                d_config=dbert_config,
                is_training=is_training,
                input_ids=masked_input_ids,
                input_mask=input_mask,
                d_input_ids=d_input_ids,
                d_input_mask=d_input_mask,
                d_location_ids=d_location_ids,
                use_target_pos_emb=dict_run_config.use_target_pos_emb,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
                d_segment_ids=d_segment_ids,
                pool_dict_output=dict_run_config.pool_dict_output,
        )

        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
                 bert_config, model.get_sequence_output(), model.get_embedding_table(),
                 masked_lm_positions, masked_lm_ids, masked_lm_weights)
        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
                 bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss

        if dict_run_config.train_op == "entry_prediction":
            score_label = features["useful_entry"] # [batch, 1]
            score_label = tf.reshape(score_label, [-1])
            entry_logits = bert_common.dense(2, bert_common.create_initializer(bert_config.initializer_range))\
                (model.get_dict_pooled_output())
            print("entry_logits: ", entry_logits.shape)
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=entry_logits, labels=score_label)
            loss = tf.reduce_mean(losses)
            total_loss = loss

        if dict_run_config.train_op == "lookup":
            lookup_idx = features["lookup_idx"]
            lookup_loss, lookup_example_loss, lookup_score = \
                sequence_index_prediction(bert_config, lookup_idx, model.get_sequence_output())

            total_loss += lookup_loss

        tvars = tf.compat.v1.trainable_variables()

        init_vars = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            if dict_run_config.is_bert_checkpoint:
                map1, map2, init_vars = get_bert_assignment_map_for_dict(tvars, train_config.init_checkpoint)

                def load_fn():
                    tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map1)
                    tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map2)
            else:
                map1, init_vars = get_assignment_map_as_is(tvars, train_config.init_checkpoint)

                def load_fn():
                    tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map1)

            if train_config.use_tpu:
                def tpu_scaffold():
                    load_fn()
                    return tf.compat.v1.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                load_fn()

        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in init_vars:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("    name = %s, shape = %s%s", var.name, var.shape, init_string)
        logging.info("Total parameters : %d" % get_param_num())

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if train_config.gradient_accumulation == 1:
                train_op = optimization.create_optimizer_from_config(total_loss, train_config)
            else:
                logging.info("Using gradient accumulation : %d" % train_config.gradient_accumulation)
                train_op = get_accumulated_optimizer_from_config(total_loss, train_config,
                                                                 tvars, train_config.gradient_accumulation)
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn, [
                    masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights, next_sentence_example_loss,
                    next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metrics=eval_metrics,
                    scaffold_fn=scaffold_fn)
        else:
            if dict_run_config.prediction_op == "gradient":
                logging.info("Fetching gradient")
                gradient = get_gradients(model, masked_lm_log_probs,
                                         train_config.max_predictions_per_seq, bert_config.vocab_size)
                predictions = {
                        "masked_input_ids": masked_input_ids,
                        #"input_ids": input_ids,
                        "d_input_ids": d_input_ids,
                        "masked_lm_positions": masked_lm_positions,
                        "gradients": gradient,
                }
            elif dict_run_config.prediction_op == "loss" or dict_run_config.prediction_op == "loss_fixed_mask":
                logging.info("Fetching loss")
                predictions = {
                    "masked_lm_example_loss": masked_lm_example_loss,
                }
            else:
                raise Exception("prediction target not specified")

            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)

        return output_spec
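The `metric_fn` referenced in the EVAL branch is not shown. A minimal sketch of a compatible one, following the standard BERT pretraining metrics (masked-LM and next-sentence accuracy; the signature matches the tensor list passed above):

    # Sketch of a metric_fn matching the eval_metrics tuple in Example #2.
    def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_log_probs, next_sentence_labels):
        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                         [-1, masked_lm_log_probs.shape[-1]])
        masked_lm_predictions = tf.argmax(masked_lm_log_probs, axis=-1,
                                          output_type=tf.int32)
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.compat.v1.metrics.accuracy(
            labels=masked_lm_ids, predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.compat.v1.metrics.mean(
            values=tf.reshape(masked_lm_example_loss, [-1]),
            weights=masked_lm_weights)

        next_sentence_log_probs = tf.reshape(
            next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
        next_sentence_predictions = tf.argmax(next_sentence_log_probs, axis=-1,
                                              output_type=tf.int32)
        next_sentence_accuracy = tf.compat.v1.metrics.accuracy(
            labels=tf.reshape(next_sentence_labels, [-1]),
            predictions=next_sentence_predictions)
        next_sentence_mean_loss = tf.compat.v1.metrics.mean(
            values=next_sentence_example_loss)

        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
            "next_sentence_accuracy": next_sentence_accuracy,
            "next_sentence_loss": next_sentence_mean_loss,
        }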
Example #3
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
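        """The `model_fn` for TPUEstimator."""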
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("    name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        if "next_sentence_labels" in features:
            next_sentence_labels = features["next_sentence_labels"]
        else:
            next_sentence_labels = get_dummy_next_sentence_labels(input_ids)
        tlm_prefix = "target_task"

        with tf.compat.v1.variable_scope(tlm_prefix):
            priority_score = tf.stop_gradient(priority_model(features))

        priority_score = priority_score * target_model_config.amp
        masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights\
            = biased_masking(input_ids,
                             input_mask,
                             priority_score,
                             target_model_config.alpha,
                             train_config.max_predictions_per_seq,
                             MASK_ID)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        all_vars = tf.compat.v1.all_variables()

        tf_logging.info("We assume priority model is from v2")

        if train_config.checkpoint_type == "v2":
            assignment_map, initialized_variable_names = assignment_map_v2_to_v2(
                all_vars, train_config.init_checkpoint)
            assignment_map2, initialized_variable_names2 = get_assignment_map_remap_from_v2(
                all_vars, tlm_prefix, train_config.second_init_checkpoint)
        else:
            assignment_map, assignment_map2, initialized_variable_names \
                = get_tlm_assignment_map_v2(all_vars,
                                            tlm_prefix,
                                            train_config.init_checkpoint,
                                            train_config.second_init_checkpoint)
            initialized_variable_names2 = None

        def init_fn():
            if train_config.init_checkpoint:
                tf.compat.v1.train.init_from_checkpoint(
                    train_config.init_checkpoint, assignment_map)
            if train_config.second_init_checkpoint:
                tf.compat.v1.train.init_from_checkpoint(
                    train_config.second_init_checkpoint, assignment_map2)

        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

        tvars = [v for v in all_vars if not v.name.startswith(tlm_prefix)]
        log_var_assignments(tvars, initialized_variable_names,
                            initialized_variable_names2)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                total_loss, train_config, tvars)
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "priority_score": priority_score,
                "lm_loss1": features["loss1"],
                "lm_loss2": features["loss2"],
            }
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                predictions=predictions,
                scaffold_fn=scaffold_fn)

        return output_spec
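The helper `get_tpu_scaffold_or_init` used in Examples #1 and #3 is not defined here, but its behavior can be inferred from the inline scaffold logic in Example #2. A minimal sketch under that assumption:

    # Sketch of get_tpu_scaffold_or_init, mirroring the inline logic in Example #2:
    # on TPU, defer initialization into a Scaffold factory; otherwise run it now.
    def get_tpu_scaffold_or_init(init_fn, use_tpu):
        if use_tpu:
            def tpu_scaffold():
                init_fn()
                return tf.compat.v1.train.Scaffold()
            return tpu_scaffold
        init_fn()
        return None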
Example #4
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("    name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        next_sentence_labels = get_dummy_next_sentence_labels(input_ids)
        batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
        n_trial = seq_length - 20

        masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
            = one_by_one_masking(input_ids, input_mask, MASK_ID, n_trial)
        num_classes = train_config.num_classes
        n_repeat = num_classes * n_trial

        # [ num_classes * n_trial * batch_size, seq_length]
        repeat_masked_input_ids = tf.tile(masked_input_ids, [num_classes, 1])
        repeat_input_mask = tf.tile(input_mask, [n_repeat, 1])
        repeat_segment_ids = tf.tile(segment_ids, [n_repeat, 1])
        masked_lm_positions = tf.tile(masked_lm_positions, [num_classes, 1])
        masked_lm_ids = tf.tile(masked_lm_ids, [num_classes, 1])
        masked_lm_weights = tf.tile(masked_lm_weights, [num_classes, 1])
        next_sentence_labels = tf.tile(next_sentence_labels, [n_repeat, 1])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        virtual_labels_ids = tf.tile(tf.expand_dims(tf.range(num_classes), 0),
                                     [1, batch_size * n_trial])
        virtual_labels_ids = tf.reshape(virtual_labels_ids, [-1, 1])

        print("repeat_masked_input_ids", repeat_masked_input_ids.shape)
        print("repeat_input_mask", repeat_input_mask.shape)
        print("virtual_labels_ids", virtual_labels_ids.shape)
        model = BertModelWithLabelInner(
            config=model_config,
            is_training=is_training,
            input_ids=repeat_masked_input_ids,
            input_mask=repeat_input_mask,
            token_type_ids=repeat_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            label_ids=virtual_labels_ids,
        )
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output_fn(
             model_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             model_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss

        # loss = -log(prob)
        # TODO compare log prob of each label

        # Regroup per-example losses by virtual class and sum within each class.
        per_case_loss = tf.reshape(masked_lm_example_loss,
                                   [num_classes, -1, batch_size])
        per_label_loss = tf.reduce_sum(per_case_loss, axis=1)
        bias = tf.zeros([num_classes, 1])
        per_label_score = tf.transpose(-per_label_loss + bias, [1, 0])

        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names, initialized_variable_names2, init_fn\
            = align_checkpoint_for_lm(tvars,
                                      train_config.checkpoint_type,
                                      train_config.init_checkpoint,
                                      train_config.second_init_checkpoint,
                                      )

        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

        log_var_assignments(tvars, initialized_variable_names,
                            initialized_variable_names2)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                total_loss, train_config)
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[OomReportingHook()],
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn_lm, [
                masked_lm_example_loss,
                masked_lm_log_probs,
                masked_lm_ids,
                masked_lm_weights,
            ])
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            predictions = {"input_ids": input_ids, "logits": per_label_score}
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                predictions=predictions,
                scaffold_fn=scaffold_fn)

        return output_spec
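To sanity-check the per-class scoring at the end of Example #4, here is a toy-sized sketch (num_classes=3, n_trial=2, batch_size=4), assuming one loss value per row of the tiled batch:

    # Toy check of the loss regrouping used for per_label_score above.
    import tensorflow as tf

    num_classes, n_trial, batch_size = 3, 2, 4
    # One loss per tiled row: [num_classes * n_trial * batch_size].
    example_loss = tf.range(num_classes * n_trial * batch_size, dtype=tf.float32)

    # The class index is the outermost tiling axis, so reshaping regroups by class.
    per_case_loss = tf.reshape(example_loss, [num_classes, -1, batch_size])
    per_label_loss = tf.reduce_sum(per_case_loss, axis=1)    # [num_classes, batch_size]
    per_label_score = tf.transpose(-per_label_loss, [1, 0])  # [batch_size, num_classes]
    print(per_label_score.shape)  # (4, 3): one score per class for each example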