Example #1
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        config,
        train_config,
        BertModel,
        special_flags
    )

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(input_files, FLAGS.max_seq_length, is_training, FLAGS,
                                                   num_cpu_threads=4,
                                                   repeat_for_eval=False)
    else:
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)

    result = run_estimator(model_fn, input_fn)
    return result
Example #2
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_hinge,
        )
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)

    else:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_pred,
        )

        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Example #3
def main(_):
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder(input_files, FLAGS, False)
    model_fn = model_fn_lm(bert_config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #4
def input_fn_perspective_passage(flags):
    input_files = get_input_files_from_flags(flags)
    max_seq_length = flags.max_seq_length
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = dict({
            "input_ids1":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask1":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids1":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_ids2":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask2":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids2":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
        })
        name_to_features["strict_good"] = tf.io.FixedLenFeature([1], tf.int64)
        name_to_features["strict_bad"] = tf.io.FixedLenFeature([1], tf.int64)
        return format_dataset(name_to_features, batch_size, is_training, flags,
                              input_files, num_cpu_threads)

    return input_fn
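The closure returned above follows the TF Estimator convention of receiving a params dict that carries the batch size. As a minimal illustrative sketch (not part of the scraped example; the batch size and the use of FLAGS from the surrounding snippets are assumptions), it would be invoked roughly like this:

def demo_perspective_passage_dataset():
    # Hypothetical invocation sketch: the Estimator normally supplies params;
    # here a batch_size is passed by hand just to show the expected interface.
    input_fn = input_fn_perspective_passage(FLAGS)
    dataset = input_fn({"batch_size": 32})
    return dataset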
Example #5
def main(_):
    tf_logging.info("Run MSMarco")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder(input_files, FLAGS.max_seq_length, is_training)
    model_fn = model_fn_classification(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #6
def main(_):
    tf_logging.info("label_as_token")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #7
def main(_):
    set_level_debug()
    tf_logging.info("Train horizon")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, BertologyFactory(HorizontalAlpha), get_masked_lm_output_albert)
    return run_estimator(model_fn, input_fn)
Example #8
def main(_):
    tf_logging.info("Train albert")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, Albert.factory,
                           get_masked_lm_output_albert)
    return run_estimator(model_fn, input_fn)
Example #9
def main(_):
    tf_logging.info("Train topic_vector")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_topic_fn(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, TopicVectorBert.factory,
                           get_masked_lm_output, True)
    return run_estimator(model_fn, input_fn)
Example #10
def main(_):
    set_level_debug()
    tf_logging.info("Train reshape bert")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, ReshapeBertModel)
    return run_estimator(model_fn, input_fn)
Example #11
def main(_):
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_file = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_classification_w_data_id(input_file,
                                                        FLAGS,
                                                        FLAGS.do_train)
    model_fn = model_fn_pooling_long_things(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #12
def main(_):
    tf_logging.info("Train BertModelWithLabel")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_classification(input_files, FLAGS.max_seq_length, is_training, FLAGS, repeat_for_eval=True)
    model_fn = model_fn_lm(config, train_config, BertModelWithLabel,
                           get_masked_lm_output_fn=get_masked_lm_output,
                           feed_feature=True)
    return run_estimator(model_fn, input_fn)
Example #13
def main(_):
    tf_logging.info("Regression with weigth")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_regression(get_input_files_from_flags(FLAGS),
                                           FLAGS, FLAGS.do_train)

    model_fn = model_fn_regression(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Example #14
def main(_):
    tf_logging.info("Run NLI with BERT but with file that contain alt_emb_ids")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    train_config = TrainConfigEx.from_flags(FLAGS)

    input_fn = input_fn_builder_alt_emb2_classification(
        input_files, FLAGS, is_training)

    model_fn = model_fn_classification(config, train_config, BertModel)
    run_estimator(model_fn, input_fn)
Example #15
def run_classification_w_second_input():
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_classification(config, train_config, ME5_2,
                                       special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    result = run_estimator(model_fn, input_fn)
    return result
Example #16
def main(_):
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)

    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_fn = input_fn_builder(get_input_files_from_flags(FLAGS), FLAGS,
                                is_training)
    model_fn = mask_lm_as_seq2seq(config, train_config)

    run_estimator(model_fn, input_fn)
Example #17
def main(_):
    tf_logging.info("Train MLM  with BERT like")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_unmasked_alt_emb2(input_files, FLAGS,
                                                  is_training)

    model_fn = model_fn_lm(config, train_config, BertModel,
                           get_masked_lm_output)
    run_estimator(model_fn, input_fn)
Example #18
def main(_):
    input_files = get_input_files_from_flags(FLAGS)

    show_input_files(input_files)

    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction(
            input_files=input_files, max_seq_length=FLAGS.max_seq_length)
    else:
        assert False

    result = run_estimator(model_fn, input_fn)
    return result
Example #19
def run_classification_w_second_input():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification(
        bert_config,
        train_config,
    )
    input_fn = input_fn_builder_use_second_input(FLAGS)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Example #20
def run_w_data_id():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification_weighted_loss(
        bert_config,
        train_config,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #21
def main(_):
    tf_logging.info("Token scoring")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    input_fn = input_fn_token_scoring2(input_files, FLAGS, FLAGS.do_train)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    model_fn = model_fn_token_scoring(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Example #22
def main(_):
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_file = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_prediction(input_file,
                                           config.total_sequence_length)

    if FLAGS.modeling == "epsilon":
        model_class = SeroEpsilon
    elif FLAGS.modeling == "zeta":
        model_class = SeroZeta
    else:
        assert False
    model_fn = model_fn_sero_ranking_predict(config, train_config, model_class)
    return run_estimator(model_fn, input_fn)
Example #23
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_classification_with_ada(
        config,
        train_config,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_ada(input_files, FLAGS, FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #24
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_sensitivity(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=special_flags,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    result = run_estimator(model_fn, input_fn)
    return result
Example #25
def main(_):
    input_files = get_input_files_from_flags(FLAGS)

    show_input_files(input_files)

    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction_w_data_id(
            input_files=input_files, max_seq_length=FLAGS.max_seq_length)
    else:
        assert False

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    result = run_estimator(model_fn, input_fn)
    return result
Example #26
def run_w_data_id():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=special_flags,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #27
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=FreezeEmbedding,
        special_flags=special_flags,
    )

    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)

    result = run_estimator(model_fn, input_fn)
    return result
Example #28
def main_inner(model_class=None):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    if model_class is None:
        model_class = BertModel

    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=model_class,
        special_flags=special_flags,
    )

    input_fn = input_fn_builder_prediction(get_input_files_from_flags(FLAGS),
                                           FLAGS.max_seq_length)
    r = run_estimator(model_fn, input_fn)
    return r
Example #29
def main(_):
    tf_logging.info("Train MLM  with alternative embedding")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_unmasked_alt_emb(input_files, FLAGS,
                                                 is_training)

    def model_constructor(config, is_training, input_ids, input_mask,
                          token_type_ids, use_one_hot_embeddings, features):
        return EmbeddingReplacer(config, is_training, input_ids, input_mask,
                                 token_type_ids, use_one_hot_embeddings,
                                 features)

    model_fn = model_fn_lm(config, train_config, model_constructor,
                           get_masked_lm_output, True)
    run_estimator(model_fn, input_fn)
Example #30
def input_fn_builder_multi_context_classification(max_seq_length, max_context, max_context_length, flags):
    input_files = get_input_files_from_flags(flags)
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4
    raw_context_len = max_context * max_context_length
    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]
        name_to_features = {
            "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "context_input_ids": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "context_input_mask": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "context_segment_ids": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "label_ids": tf.io.FixedLenFeature([1], tf.int64),
        }
        return format_dataset(name_to_features, batch_size, is_training, flags, input_files, num_cpu_threads)

    return input_fn
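Because the context features above are serialized flattened to max_context * max_context_length, a model_fn consuming this dataset would typically reshape them back into per-context rows. A minimal sketch under that assumption (the helper name and shapes are illustrative, not part of the example):

import tensorflow as tf

def split_contexts(features, max_context, max_context_length):
    # Hypothetical post-processing sketch: recover
    # [batch, max_context, max_context_length] from the flattened tensors
    # produced by the input_fn above.
    shape = [-1, max_context, max_context_length]
    context_input_ids = tf.reshape(features["context_input_ids"], shape)
    context_input_mask = tf.reshape(features["context_input_mask"], shape)
    context_segment_ids = tf.reshape(features["context_segment_ids"], shape)
    return context_input_ids, context_input_mask, context_segment_ids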