Пример #1
0
def main(_):
    """Run the estimator for Albert-based classification.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config and input function from ``FLAGS``, and builds the
    classification model function with ``Albert.factory``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Train albert")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    return run_estimator(
        model_fn_classification(model_config, run_config, Albert.factory),
        data_fn,
    )
Пример #2
0
def main(_):
    """Run the estimator for the APR classification model.

    Input files are collected by globbing every comma-separated pattern in
    ``FLAGS.input_file``. The BERT config is read from
    ``FLAGS.bert_config_file`` and the extra model config from
    ``FLAGS.model_config_file``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.

    Raises:
        Exception: if none of ``do_train``, ``do_eval`` or ``do_predict`` is
            set on ``FLAGS``.
    """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    ssdr_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    model_fn = model_fn_apr_classification(
        bert_config=bert_config,
        ssdr_config=ssdr_config,
        train_config=train_config,
        dict_run_config=DictRunConfig.from_flags(FLAGS),
    )

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict):
        raise Exception(
            "One of --do_train, --do_eval or --do_predict must be set")

    # Training takes precedence: the pipeline is built in training mode
    # whenever do_train is set, and in non-training mode otherwise.
    input_fn = input_fn_builder(input_files=input_files,
                                max_seq_length=FLAGS.max_seq_length,
                                is_training=bool(FLAGS.do_train))

    return run_estimator(model_fn, input_fn)
Пример #3
0
def main(_):
    """Run the estimator for pairwise text-pair ranking with ``BertModel``.

    The model config comes from ``FLAGS.model_config_file``; the input
    function is produced by ``input_fn_perspective_passage``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("text pair ranking")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    passage_input_fn = input_fn_perspective_passage(FLAGS)
    ranking_model_fn = text_pair_ranking_pairwise_model(
        model_config,
        run_config,
        BertModel,
        "",
    )
    outcome = run_estimator(ranking_model_fn, passage_input_fn)
    return outcome
Пример #4
0
def main(_):
    """Run the estimator for one of the TripleBert classification models.

    ``FLAGS.modeling`` selects the model class ("TripleBertMasking" or
    "TripleBertWeighted"). ``FLAGS.special_flags`` is split on commas and
    "feed_features" is appended before it is handed to
    ``model_fn_classification``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.

    Raises:
        ValueError: if ``FLAGS.modeling`` is not one of the supported names.
    """
    tf_logging.info("TripleBertMasking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_triple(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        """Copy every entry of ``model.get_predictions()`` into ``predictions``."""
        predictions.update(model.get_predictions())
        return predictions

    model_classes = {
        "TripleBertMasking": TripleBertMasking,
        "TripleBertWeighted": TripleBertWeighted,
    }
    # An explicit error is used instead of `assert False`: assertions are
    # stripped under `python -O`, which would leave the model class unbound.
    if FLAGS.modeling not in model_classes:
        raise ValueError(
            "Unsupported modeling option {!r}; expected one of {}".format(
                FLAGS.modeling, sorted(model_classes)))
    model_class = model_classes[FLAGS.modeling]

    model_fn = model_fn_classification(config, train_config, model_class,
                                       special_flags, override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Пример #5
0
def main(_):
    """Run the estimator for the loss-diff prediction model.

    Input files are collected by globbing every comma-separated pattern in
    ``FLAGS.input_file``. The BERT config is read from
    ``FLAGS.bert_config_file`` and the model config from
    ``FLAGS.model_config_file``.

    The positional argument is ignored. Nothing is returned.

    Raises:
        Exception: if none of ``do_train``, ``do_eval`` or ``do_predict`` is
            set on ``FLAGS``.
    """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    show_input_files(input_files)

    model_fn = loss_diff_prediction_model(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict):
        raise Exception(
            "One of --do_train, --do_eval or --do_predict must be set")

    # Training takes precedence: the pipeline is built in training mode
    # whenever do_train is set, and in non-training mode otherwise.
    input_fn = input_fn_builder_masked(
        input_files=input_files,
        flags=FLAGS,
        is_training=bool(FLAGS.do_train))

    run_estimator(model_fn, input_fn)
Пример #6
0
def main(_):
    """Run the estimator for the MES model.

    When ``FLAGS.do_train`` or ``FLAGS.do_eval`` is set, the model function
    is built with ``MES_hinge`` and fed by ``input_fn_builder_pairwise``.
    Otherwise it is built with ``MES_pred`` and fed by
    ``input_fn_builder_classification_w_data_id2``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    is_training = FLAGS.do_train

    if FLAGS.do_train or FLAGS.do_eval:
        model_fn = model_fn_with_loss(config, train_config, MES_hinge)
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)
    else:
        model_fn = model_fn_with_loss(config, train_config, MES_pred)
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Пример #7
0
def main(_):
    """Run the estimator for a ``BertModel`` language model.

    The BERT config is read from ``FLAGS.bert_config_file``. The input
    function is built with ``False`` as its third positional argument.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    lm_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_builder(get_input_files_from_flags(FLAGS), FLAGS, False)
    lm_model_fn = model_fn_lm(lm_config, run_config, BertModel)
    return run_estimator(lm_model_fn, data_fn)
Пример #8
0
def main(_):
    """Run the estimator for the generative predictor.

    The model config comes from ``FLAGS.model_config_file``; inputs come
    from ``input_fn_from_flags`` with the classification input builder.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Run generative predictor")
    predictor_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    predictor_fn = model_fn_generative_predictor(predictor_config, run_config)
    outcome = run_estimator(predictor_fn, data_fn)
    return outcome
Пример #9
0
def main(_):
    """Run the estimator for ``BertModel`` classification.

    For train/eval the plain classification input builder is used (with
    ``repeat_for_eval=False``). Otherwise the ``..._w_data_id2`` builder is
    used.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    model_fn = model_fn_classification(
        model_config, run_config, BertModel, FLAGS.special_flags.split(","))

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    training = FLAGS.do_train
    uses_plain_builder = FLAGS.do_train or FLAGS.do_eval
    if not uses_plain_builder:
        input_fn = input_fn_builder_classification_w_data_id2(
            files, FLAGS.max_seq_length, FLAGS, training, num_cpu_threads=4)
    else:
        input_fn = input_fn_builder_classification(
            files, FLAGS.max_seq_length, training, FLAGS,
            num_cpu_threads=4, repeat_for_eval=False)

    return run_estimator(model_fn, input_fn)
Пример #10
0
def old_main():
    """Run loss-diff prediction once per matched input file.

    Every comma-separated pattern in ``FLAGS.input_file`` is globbed. Each
    matched file gets its own non-training input function and an output
    name under ``disk_output/loss_predictor_predictions/`` that reuses the
    file's last path component.

    Raises:
        Exception: if ``FLAGS.do_predict`` is not set.
    """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    matched_files = [
        path
        for pattern in FLAGS.input_file.split(",")
        for path in tf.io.gfile.glob(pattern)
    ]
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    tf_logging.addFilter(CounterFilter())

    per_file_input_fns = []
    prediction_paths = []
    for path in matched_files:
        per_file_input_fns.append(
            input_fn_builder_unmasked(input_files=[path],
                                      flags=FLAGS,
                                      is_training=False))
        base_name = path.split("/")[-1]
        prediction_paths.append(
            "disk_output/loss_predictor_predictions/" + base_name)

    model_fn = loss_diff_predict_only_model_fn(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    run_estimator_loop(model_fn, per_file_input_fns, prediction_paths)
Пример #11
0
def main_inner():
    """Build the classification model/input functions from ``FLAGS`` and run them.

    Returns the result of ``run_estimator``.
    """
    classification_fn = model_fn_classification(
        train_config=TrainConfigEx.from_flags(FLAGS),
    )
    data_fn = input_fn_from_flags(input_fn_builder, FLAGS)
    return run_estimator(classification_fn, data_fn)
Пример #12
0
def main(_):
    """Run the estimator for ``HorizontalAlpha`` classification.

    The model class passed to ``model_fn_classification`` is
    ``BertologyFactory(HorizontalAlpha)``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Train horizon classification")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    horizon_factory = BertologyFactory(HorizontalAlpha)
    classification_fn = model_fn_classification(
        model_config, run_config, horizon_factory)
    return run_estimator(classification_fn, data_fn)
Пример #13
0
def main(_):
    """Run the estimator for ``BertModel`` classification on MSMarco inputs.

    The input function is built from the flagged input files,
    ``FLAGS.max_seq_length`` and ``FLAGS.do_train``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Run MSMarco")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    data_fn = input_fn_builder(
        get_input_files_from_flags(FLAGS),
        FLAGS.max_seq_length,
        training,
    )
    return run_estimator(
        model_fn_classification(model_config, run_config, BertModel),
        data_fn,
    )
Пример #14
0
def main(_):
    """Run the estimator for plain ``BertModel`` classification.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    classification_fn = model_fn_classification(
        model_config, run_config, BertModel)
    return run_estimator(classification_fn, data_fn)
Пример #15
0
def main(_):
    """Run the estimator for ``SharingFetchGradModel``.

    Input files come from ``flags_wrapper.get_input_files()``. The unmasked
    input builder is called with positional ``False, True`` after ``FLAGS``.

    The positional argument is ignored. Nothing is returned.
    """
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    files = flags_wrapper.get_input_files()
    data_fn = input_fn_builder_unmasked(files, FLAGS, False, True)
    run_estimator(
        model_fn_share_fetch_grad(
            model_config, run_config, SharingFetchGradModel),
        data_fn,
    )
Пример #16
0
def main(_):
    """Run the estimator for the "label_as_token" ``BertModel`` language model.

    The unmasked input builder receives ``FLAGS.do_train`` as its training
    switch.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("label_as_token")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_unmasked(files, FLAGS, training)
    lm_fn = model_fn_lm(model_config, run_config, BertModel)
    outcome = run_estimator(lm_fn, data_fn)
    return outcome
Пример #17
0
def main(_):
    """Run the estimator for ``DualSeroBertModel`` classification.

    Two JSON configs are loaded: the main one from
    ``FLAGS.bert_config_file`` and a second from
    ``FLAGS.model_config_file``. The second is bound as the first argument
    of ``DualSeroBertModel`` via ``partial``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    base_config = JsonConfig.from_json_file(FLAGS.bert_config_file)
    dual_sero_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_class = partial(DualSeroBertModel, dual_sero_config)
    flag_names = FLAGS.special_flags.split(",")
    classification_fn = model_fn_classification(
        base_config, run_config, model_class, flag_names)
    return run_estimator(classification_fn, data_fn)
Пример #18
0
def main(_):
    """Run the estimator for the "pooling long things" ``BertModel`` model.

    The input function is built by
    ``input_fn_builder_classification_w_data_id`` from the flagged input
    files, ``FLAGS`` and ``FLAGS.do_train``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_classification_w_data_id(
        files, FLAGS, FLAGS.do_train)
    pooling_fn = model_fn_pooling_long_things(
        model_config, run_config, BertModel)
    return run_estimator(pooling_fn, data_fn)
Пример #19
0
def main(_):
    """Run the estimator for the ``HorizontalAlpha`` language model.

    ``set_level_debug()`` is called first. The model function is built by
    ``model_fn_lm`` with ``BertologyFactory(HorizontalAlpha)`` and
    ``get_masked_lm_output_albert``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    set_level_debug()
    tf_logging.info("Train horizon")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_unmasked(files, FLAGS, training)
    lm_fn = model_fn_lm(
        model_config,
        run_config,
        BertologyFactory(HorizontalAlpha),
        get_masked_lm_output_albert,
    )
    return run_estimator(lm_fn, data_fn)
Пример #20
0
def main(_):
    """Run the estimator for the ``ReshapeBertModel`` language model.

    ``set_level_debug()`` is called first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    set_level_debug()
    tf_logging.info("Train reshape bert")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    data_fn = input_fn_builder_unmasked(
        get_input_files_from_flags(FLAGS), FLAGS, training)
    return run_estimator(
        model_fn_lm(model_config, run_config, ReshapeBertModel),
        data_fn,
    )
Пример #21
0
def main(_):
    """Run the estimator for the Albert language model.

    The model function is built by ``model_fn_lm`` with ``Albert.factory``
    and ``get_masked_lm_output_albert``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Train albert")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_unmasked(files, FLAGS, training)
    albert_lm_fn = model_fn_lm(
        model_config, run_config, Albert.factory, get_masked_lm_output_albert)
    outcome = run_estimator(albert_lm_fn, data_fn)
    return outcome
Пример #22
0
def main(_):
    """Run the estimator for the ``TopicVectorBert`` language model.

    Inputs come from ``input_fn_topic_fn``. ``model_fn_lm`` receives
    ``TopicVectorBert.factory``, ``get_masked_lm_output`` and a trailing
    positional ``True``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Train topic_vector")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    topic_input_fn = input_fn_topic_fn(files, FLAGS, training)
    topic_model_fn = model_fn_lm(
        model_config,
        run_config,
        TopicVectorBert.factory,
        get_masked_lm_output,
        True,
    )
    return run_estimator(topic_model_fn, topic_input_fn)
Пример #23
0
def model_fn_rank_pred(FLAGS):
    """Build a ``model_fn`` that produces ranking predictions from ``BertModel``.

    Args:
        FLAGS: Flag container. ``bert_config_file`` and ``modeling`` are read
            here; the whole object is passed to ``TrainConfigEx.from_flags``.

    Returns:
        A ``model_fn(features, labels, mode, params)`` callable that returns
        the spec built by ``rank_predict_estimator_spec``.
    """
    train_config = TrainConfigEx.from_flags(FLAGS)
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    modeling_opt = FLAGS.modeling

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("model_fn_sero_classification")
        log_features(features)
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled_output = model.get_pooled_output()
        if is_training:
            # Dropout is applied to the pooled output only in TRAIN mode.
            pooled_output = dropout(pooled_output, 0.1)

        logits = get_prediction_structure(modeling_opt, pooled_output)

        # Set up checkpoint initialization for the trainable variables.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = assignment_map.get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        predictions = {
            "input_ids": input_ids,
            "logits": logits,
        }
        # Every mode except "multi_label_hinge" also forwards whichever
        # identifier features are present in the input.
        if modeling_opt != "multi_label_hinge":
            for input_name in ("data_id", "input_ids2", "data_ids"):
                if input_name in features:
                    predictions[input_name] = features[input_name]

        return rank_predict_estimator_spec(logits, mode, scaffold_fn, predictions)

    return model_fn
Пример #24
0
def main(_):
    """Run the estimator for the ``BertModelWithLabel`` language model.

    Inputs are built by the classification input builder with
    ``repeat_for_eval=True``. ``model_fn_lm`` is called with
    ``get_masked_lm_output`` and ``feed_feature=True``.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Train BertModelWithLabel")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_classification(
        files,
        FLAGS.max_seq_length,
        training,
        FLAGS,
        repeat_for_eval=True,
    )
    lm_fn = model_fn_lm(
        model_config,
        run_config,
        BertModelWithLabel,
        get_masked_lm_output_fn=get_masked_lm_output,
        feed_feature=True,
    )
    return run_estimator(lm_fn, data_fn)
Пример #25
0
def main(_):
    """Run the estimator for classification with confidence.

    Inputs come from ``input_fn_builder_two_inputs_w_data_id``. When
    ``FLAGS.do_predict`` is set, a ``MuteEnqueueFilter`` is added to
    ``tf_logging`` first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Classification with confidence")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    confidence_fn = model_fn_classification_with_confidence(
        model_config, run_config)

    predicting = FLAGS.do_predict
    if predicting:
        tf_logging.addFilter(MuteEnqueueFilter())

    outcome = run_estimator(confidence_fn, data_fn)
    return outcome
Пример #26
0
def main(_):
    """Run the estimator for the NLI/LM ``SimpleSharingModel``.

    Input files come from ``flags_wrapper.get_input_files()``. The input
    builder receives ``FLAGS.do_train`` followed by a positional ``True``.

    The positional argument is ignored. Nothing is returned.
    """
    tf_logging.info("Train nli_lm_shared")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)

    training = FLAGS.do_train
    files = flags_wrapper.get_input_files()

    data_fn = input_fn_builder(files, FLAGS, training, True)
    run_estimator(
        model_fn_nli_lm(model_config, run_config, SimpleSharingModel),
        data_fn,
    )
Пример #27
0
def main(_):
    """Run the estimator for ``BertModel`` classification with an alternative loss.

    When ``FLAGS.do_predict`` is set, a ``MuteEnqueueFilter`` is added to
    ``tf_logging`` first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Classification with alt loss")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    alt_loss_fn = model_fn_classification_with_alt_loss(
        model_config, run_config, BertModel)

    if FLAGS.do_predict:
        mute_filter = MuteEnqueueFilter()
        tf_logging.addFilter(mute_filter)

    return run_estimator(alt_loss_fn, data_fn)
Пример #28
0
def main(_):
    """Run the estimator for the regression model.

    Inputs come from ``input_fn_builder_regression`` over the flagged input
    files, with ``FLAGS.do_train`` as the training switch. When
    ``FLAGS.do_predict`` is set, a ``MuteEnqueueFilter`` is added to
    ``tf_logging`` first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Regression with weigth")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    files = get_input_files_from_flags(FLAGS)
    data_fn = input_fn_builder_regression(files, FLAGS, FLAGS.do_train)

    regression_fn = model_fn_regression(model_config, run_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    outcome = run_estimator(regression_fn, data_fn)
    return outcome
Пример #29
0
def main(_):
    """Run the estimator for the manual-combiner classification model.

    Inputs come from ``input_fn_builder_two_inputs_w_data_id``. When
    ``FLAGS.do_predict`` is set, a ``MuteEnqueueFilter`` is added to
    ``tf_logging`` first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("Manual combiner")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    model_fn = model_fn_classification_manual_combiner(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Пример #30
0
def main(_):
    """Run the estimator for ``DualBertTwoInputWRel`` classification.

    ``FLAGS.special_flags`` is split on commas and "feed_features" is
    appended before it is passed to ``model_fn_classification``. When
    ``FLAGS.do_predict`` is set, a ``MuteEnqueueFilter`` is added to
    ``tf_logging`` first.

    The positional argument is ignored. Returns the result of
    ``run_estimator``.
    """
    tf_logging.info("QCK with rel")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    data_fn = input_fn_builder_two_inputs_w_rel(FLAGS)
    flag_names = FLAGS.special_flags.split(",") + ["feed_features"]
    classification_fn = model_fn_classification(
        model_config, run_config, DualBertTwoInputWRel, flag_names)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    outcome = run_estimator(classification_fn, data_fn)
    return outcome