Example #1
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
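    # Expand the comma-separated glob patterns from --input_file into a concrete file list.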
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    show_input_files(input_files)

    model_fn = loss_diff_prediction_model(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
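    # Pick the training or the evaluation/prediction input pipeline based on the run-mode flags.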
    if FLAGS.do_train:
        input_fn = input_fn_builder_masked(
            input_files=input_files,
            flags=FLAGS,
            is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder_masked(
            input_files=input_files,
            flags=FLAGS,
            is_training=False)
    else:
        raise Exception("Either --do_train, --do_eval, or --do_predict must be set.")

    run_estimator(model_fn, input_fn)
Example #2
def main(_):
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    input_files = flags_wrapper.get_input_files()

    input_fn = input_fn_builder_unmasked(input_files, FLAGS, False, True)
    model_fn = model_fn_share_fetch_grad(config, train_config, SharingFetchGradModel)
    run_estimator(model_fn, input_fn)
Example #3
def main(_):
    is_training = FLAGS.do_train

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))

    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_builder()

    run_estimator(model_fn, input_fn)
Example #4
def main(_):
    tf_logging.info("Train nli_lm_shared")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = flags_wrapper.get_input_files()

    input_fn = input_fn_builder(input_files, FLAGS, is_training, True)
    model_fn = model_fn_nli_lm(config, train_config, SimpleSharingModel)
    run_estimator(model_fn, input_fn)
Example #5
def main(_):
    tf_logging.info("Train MLM")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))

    input_fn = input_fn_builder_masked2(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, BertModel)
    run_estimator(model_fn, input_fn)
Example #6
def main(_):
    tf_logging.info("Run NLI with BERT but with file that contain alt_emb_ids")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    train_config = TrainConfigEx.from_flags(FLAGS)

    input_fn = input_fn_builder_alt_emb2_classification(
        input_files, FLAGS, is_training)

    model_fn = model_fn_classification(config, train_config, BertModel)
    run_estimator(model_fn, input_fn)
Example #7
def main(_):
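    # Log at INFO by default; switch to DEBUG when --log_debug is set.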
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)

    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_fn = input_fn_builder(get_input_files_from_flags(FLAGS), FLAGS,
                                is_training)
    model_fn = mask_lm_as_seq2seq(config, train_config)

    run_estimator(model_fn, input_fn)
Example #8
def main(_):
    tf_logging.info("Train MLM  with BERT like")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_unmasked_alt_emb2(input_files, FLAGS,
                                                  is_training)

    model_fn = model_fn_lm(config, train_config, BertModel,
                           get_masked_lm_output)
    run_estimator(model_fn, input_fn)
Example #9
def main(_):
    tf_logging.info("text pair ranking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_perspective_passage(FLAGS)
    model_fn = text_pair_ranking_pairwise_model(config, train_config, BertModel, "")
    return run_estimator(model_fn, input_fn)
Example #10
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    is_training = FLAGS.do_train
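    # Train/eval uses the pairwise hinge model (MES_hinge); prediction switches to MES_pred
    # with an input pipeline that also carries data ids.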
    if FLAGS.do_train or FLAGS.do_eval:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_hinge,
        )
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)

    else:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_pred,
        )

        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Example #11
def main(_):
    tf_logging.info("Train albert")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_classification(config, train_config, Albert.factory)
    return run_estimator(model_fn, input_fn)
Example #12
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    ssdr_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    model_fn = model_fn_apr_classification(
        bert_config=bert_config,
        ssdr_config=ssdr_config,
        train_config=train_config,
        dict_run_config=DictRunConfig.from_flags(FLAGS),
    )
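    # Build the input pipeline for training or for evaluation/prediction, depending on the mode flags.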
    if FLAGS.do_train:
        input_fn = input_fn_builder(input_files=input_files,
                                    max_seq_length=FLAGS.max_seq_length,
                                    is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder(input_files=input_files,
                                    max_seq_length=FLAGS.max_seq_length,
                                    is_training=False)
    else:
        raise Exception("Either --do_train, --do_eval, or --do_predict must be set.")

    return run_estimator(model_fn, input_fn)
Example #13
def main(_):
    tf_logging.info("Run generative predictor")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_generative_predictor(config, train_config)
    return run_estimator(model_fn, input_fn)
Example #14
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        config,
        train_config,
        BertModel,
        special_flags
    )

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    is_training = FLAGS.do_train
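    # Train/eval reads plain classification examples; prediction uses the variant that also
    # feeds data ids alongside each example.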
    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(input_files, FLAGS.max_seq_length, is_training, FLAGS,
                                                   num_cpu_threads=4,
                                                   repeat_for_eval=False)
    else:
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)

    result = run_estimator(model_fn, input_fn)
    return result
Example #15
def main(_):
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder(input_files, FLAGS, False)
    model_fn = model_fn_lm(bert_config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #16
def main(_):
    tf_logging.info("TripleBertMasking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_triple(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

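    # Merge any extra tensors the model exposes via get_predictions() into the estimator's
    # prediction dict.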
    def override_prediction_fn(predictions, model):
        for key, value in model.get_predictions().items():
            predictions[key] = value
        return predictions

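    # Choose the concrete model class from the --modeling flag.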
    if FLAGS.modeling == "TripleBertMasking":
        model_class = TripleBertMasking
    elif FLAGS.modeling == "TripleBertWeighted":
        model_class = TripleBertWeighted
    else:
        assert False, "Unknown modeling option: " + FLAGS.modeling

    model_fn = model_fn_classification(config, train_config, model_class,
                                       special_flags, override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Example #17
def main_inner():
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification(
        train_config=train_config,
    )
    input_fn = input_fn_from_flags(input_fn_builder, FLAGS)
    r = run_estimator(model_fn, input_fn)
    return r
Example #18
def main(_):
    tf_logging.info("Train horizon classification")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_classification(config, train_config,
                                       BertologyFactory(HorizontalAlpha))
    return run_estimator(model_fn, input_fn)
Example #19
def main(_):
    tf_logging.info("Run MSMarco")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder(input_files, FLAGS.max_seq_length, is_training)
    model_fn = model_fn_classification(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #20
def main(_):
    tf_logging.info("label_as_token")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #21
def main(_):
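    # Two configs are loaded: the base BERT config and a separate config for the sero part
    # of DualSeroBertModel.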
    config = JsonConfig.from_json_file(FLAGS.bert_config_file)
    sero_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_classification(config, train_config,
                                       partial(DualSeroBertModel, sero_config),
                                       FLAGS.special_flags.split(","))
    return run_estimator(model_fn, input_fn)
Example #22
def main(_):
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)

    model_fn = model_fn_classification(config, train_config, BertModel)
    r = run_estimator(model_fn, input_fn)
    return r
Example #23
def main(_):
    tf_logging.info("Train topic_vector")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_topic_fn(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, TopicVectorBert.factory,
                           get_masked_lm_output, True)
    return run_estimator(model_fn, input_fn)
Example #24
def main(_):
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_classification_w_data_id(
        input_files, FLAGS, FLAGS.do_train)
    model_fn = model_fn_pooling_long_things(config, train_config, BertModel)
    return run_estimator(model_fn, input_fn)
Example #25
def main(_):
    tf_logging.info("Train albert")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, Albert.factory,
                           get_masked_lm_output_albert)
    return run_estimator(model_fn, input_fn)
Example #26
def main(_):
    set_level_debug()
    tf_logging.info("Train horizon")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config,
                           BertologyFactory(HorizontalAlpha),
                           get_masked_lm_output_albert)
    return run_estimator(model_fn, input_fn)
Example #27
def main(_):
    set_level_debug()
    tf_logging.info("Train reshape bert")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_unmasked(input_files, FLAGS, is_training)
    model_fn = model_fn_lm(config, train_config, ReshapeBertModel)
    return run_estimator(model_fn, input_fn)
Example #28
def main(_):
    tf_logging.info("Classification with alt loss")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)

    model_fn = model_fn_classification_with_alt_loss(config, train_config, BertModel)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Example #29
def main(_):
    tf_logging.info("Train MLM  with alternative embedding")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)

    is_training = FLAGS.do_train
    input_files = get_input_files_from_flags(FLAGS)

    input_fn = input_fn_builder_unmasked_alt_emb(input_files, FLAGS,
                                                 is_training)

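    # Thin adapter that forwards the extra `features` argument to EmbeddingReplacer,
    # presumably to match the constructor signature model_fn_lm expects here.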
    def model_constructor(config, is_training, input_ids, input_mask,
                          token_type_ids, use_one_hot_embeddings, features):
        return EmbeddingReplacer(config, is_training, input_ids, input_mask,
                                 token_type_ids, use_one_hot_embeddings,
                                 features)

    model_fn = model_fn_lm(config, train_config, model_constructor,
                           get_masked_lm_output, True)
    run_estimator(model_fn, input_fn)
Example #30
def main(_):
    tf_logging.info("Classification with confidence")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)

    model_fn = model_fn_classification_with_confidence(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)