Exemplo n.º 1
0
def main(_):
    """Run a triple-input BERT classifier chosen by ``FLAGS.modeling``.

    Supported values of ``FLAGS.modeling`` are ``"TripleBertMasking"`` and
    ``"TripleBertWeighted"``. Input comes from
    ``input_fn_builder_cppnc_triple(FLAGS)``. The ``"feed_features"`` special
    flag is always appended to the comma-separated ``FLAGS.special_flags``.

    Returns:
        Whatever ``run_estimator`` returns.

    Raises:
        ValueError: if ``FLAGS.modeling`` is not one of the supported names.
            (Previously an ``assert False``, which ``python -O`` strips,
            leading to a confusing NameError on ``model_class``.)
    """
    tf_logging.info("TripleBertMasking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_triple(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Merge the model's own prediction entries into the estimator's
        # prediction dict; on key collisions the model's values win.
        predictions.update(model.get_predictions())
        return predictions

    if FLAGS.modeling == "TripleBertMasking":
        model_class = TripleBertMasking
    elif FLAGS.modeling == "TripleBertWeighted":
        model_class = TripleBertWeighted
    else:
        raise ValueError(
            "Unsupported FLAGS.modeling {!r}; expected 'TripleBertMasking' "
            "or 'TripleBertWeighted'".format(FLAGS.modeling))

    model_fn = model_fn_classification(config, train_config, model_class,
                                       special_flags, override_prediction_fn)
    if FLAGS.do_predict:
        # Suppress noisy enqueue log lines during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Exemplo n.º 2
0
def main(_):
    """Train an ALBERT classifier (``Albert.factory``) from flag settings.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("Train albert")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    classification_input_fn = input_fn_from_flags(
        input_fn_builder_classification, FLAGS)
    classification_model_fn = model_fn_classification(
        model_config, run_config, Albert.factory)
    return run_estimator(classification_model_fn, classification_input_fn)
Exemplo n.º 3
0
def main(_):
    """Run a ``BertModel`` classifier, choosing the input pipeline by mode.

    When training or evaluating, ``input_fn_builder_classification`` is used.
    Otherwise ``input_fn_builder_classification_w_data_id2`` is used.

    Returns:
        The result of ``run_estimator``.
    """
    input_files = get_input_files_from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    flag_list = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(model_config, run_config, BertModel,
                                       flag_list)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    training = FLAGS.do_train
    train_or_eval = FLAGS.do_train or FLAGS.do_eval
    if not train_or_eval:
        # Neither training nor evaluating: use the data-id pipeline.
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files, FLAGS.max_seq_length, FLAGS, training,
            num_cpu_threads=4)
    else:
        input_fn = input_fn_builder_classification(
            input_files, FLAGS.max_seq_length, training, FLAGS,
            num_cpu_threads=4, repeat_for_eval=False)

    return run_estimator(model_fn, input_fn)
Exemplo n.º 4
0
def main(_):
    """Train a classifier built from ``BertologyFactory(HorizontalAlpha)``.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("Train horizon classification")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    factory = BertologyFactory(HorizontalAlpha)
    model_fn = model_fn_classification(model_config, run_config, factory)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 5
0
def main(_):
    """Run a ``DualSeroBertModel`` classifier.

    The BERT config is read from ``FLAGS.bert_config_file``. The sero-model
    config is read from ``FLAGS.model_config_file`` and bound into the model
    class via ``partial``.

    Returns:
        The result of ``run_estimator``.
    """
    bert_cfg = JsonConfig.from_json_file(FLAGS.bert_config_file)
    sero_cfg = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_factory = partial(DualSeroBertModel, sero_cfg)
    flag_list = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(bert_cfg, run_config, model_factory,
                                       flag_list)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 6
0
def main(_):
    """Run a ``BertModel`` classifier on MSMarco inputs built by ``input_fn_builder``.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("Run MSMarco")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder(files, FLAGS.max_seq_length, training)
    model_fn = model_fn_classification(model_config, run_config, BertModel)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 7
0
def main(_):
    """Run ``DualBertTwoInputModelEx`` on two-input data that carries data ids.

    ``"feed_features"`` is always appended to the special flags.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("DualBertTwoInputModelEx")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, run_config,
                                       DualBertTwoInputModelEx, flag_list)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Exemplo n.º 8
0
def main(_):
    """Run a ``BertModel`` NLI classifier on input files that include alt_emb_ids.

    The estimator result is not returned; this function returns None.
    """
    tf_logging.info("Run NLI with BERT but with file that contain alt_emb_ids")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    run_config = TrainConfigEx.from_flags(FLAGS)

    input_fn = input_fn_builder_alt_emb2_classification(files, FLAGS,
                                                        training)
    model_fn = model_fn_classification(model_config, run_config, BertModel)
    run_estimator(model_fn, input_fn)
Exemplo n.º 9
0
def main(_):
    """Run the ``ME7`` multi-evidence QCK classifier.

    ``"feed_features"`` is always appended to the special flags.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("QCK with ME7")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, run_config, ME7,
                                       flag_list)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Exemplo n.º 10
0
def main(_):
    """Run ``DualBertTwoInputWithDoubleInputLength`` (three-input QCK setup).

    ``"feed_features"`` is always appended to the special flags.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("ThreeInput QCK")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_dual_bert_double_length_input(FLAGS)
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(
        model_config, run_config, DualBertTwoInputWithDoubleInputLength,
        flag_list)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Exemplo n.º 11
0
def run_classification_w_second_input():
    """Run the ``ME5_2`` multi-evidence classifier.

    The input file list is only logged via ``show_input_files``; the actual
    input pipeline comes from ``input_fn_builder_cppnc_multi_evidence(FLAGS)``.

    NOTE(review): this listing contains another function with this same name;
    if both end up in one module, the later definition wins.

    Returns:
        The result of ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, run_config, ME5_2,
                                       flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn,
                         input_fn_builder_cppnc_multi_evidence(FLAGS))
def run_classification_w_second_input():
    """Run a ``BertModel`` classifier on ``input_fn_builder_use_second_input`` data.

    The BERT config is read as a ``BertConfig`` from ``FLAGS.bert_config_file``.

    Returns:
        The result of ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    bert_cfg = BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    model_kwargs = dict(
        bert_config=bert_cfg,
        train_config=run_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    model_fn = model_fn_classification(**model_kwargs)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 13
0
def main(_):
    """Run a ``FreezeEmbedding`` classifier on input that carries data ids.

    Returns:
        The result of ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    bert_cfg = BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    flag_list = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_cfg,
        train_config=run_config,
        model_class=FreezeEmbedding,
        special_flags=flag_list,
    )

    input_fn = input_fn_builder_classification_w_data_id(
        input_files=files,
        flags=FLAGS,
        is_training=FLAGS.do_train,
    )
    return run_estimator(model_fn, input_fn)
Exemplo n.º 14
0
def run_w_data_id():
    """Run a ``BertModel`` classifier on input built by the data-ids builder.

    In predict mode a ``CounterFilter`` is attached to the TF logger.

    Returns:
        The result of ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    bert_cfg = BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    model_fn = model_fn_classification(
        bert_config=bert_cfg,
        train_config=run_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    # The builder's name (including "_typo") is defined elsewhere; kept as-is.
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=files, flags=FLAGS, is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 15
0
def main_inner(model_class=None):
    """Run a classifier with the given model class.

    Args:
        model_class: model class passed to ``model_fn_classification``.
            Defaults to ``BertModel`` when None.

    Returns:
        The result of ``run_estimator``.
    """
    bert_cfg = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)

    chosen_class = BertModel if model_class is None else model_class

    model_fn = model_fn_classification(
        bert_config=bert_cfg,
        train_config=run_config,
        model_class=chosen_class,
        special_flags=FLAGS.special_flags.split(","),
    )
    input_fn = input_fn_from_flags(input_fn_builder, FLAGS)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 16
0
def main(_):
    """Run ``MultiEvidenceUseFirst`` and expose its output as the ``'vector'`` prediction.

    Returns:
        The result of ``run_estimator``.
    """
    files = get_input_files_from_flags(FLAGS)
    bert_cfg = BertConfig.from_json_file(FLAGS.bert_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(files)
    flag_list = FLAGS.special_flags.split(",")

    def add_vector_prediction(predictions, model):
        # Attach the model's output tensor under the 'vector' key.
        predictions['vector'] = model.get_output()
        return predictions

    model_fn = model_fn_classification(
        bert_config=bert_cfg,
        train_config=run_config,
        model_class=MultiEvidenceUseFirst,
        special_flags=flag_list,
        override_prediction_fn=add_vector_prediction,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
Exemplo n.º 17
0
def main(_):
    """Run an ``EmbeddingReplacer2`` classifier on alt-embedding (v2) input files.

    ``"feed_features"`` is always appended to the special flags. The estimator
    result is not returned; this function returns None.
    """
    tf_logging.info("Train MLM  with alternative embedding2")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)

    training = FLAGS.do_train
    files = get_input_files_from_flags(FLAGS)
    input_fn = input_fn_builder_alt_emb2_classification(files, FLAGS,
                                                        training)

    # Adapter with named parameters that forwards positionally to
    # EmbeddingReplacer2, so callers may pass these names as keywords.
    def model_constructor(config, is_training, input_ids, input_mask,
                          token_type_ids, use_one_hot_embeddings, features):
        ctor_args = (config, is_training, input_ids, input_mask,
                     token_type_ids, use_one_hot_embeddings, features)
        return EmbeddingReplacer2(*ctor_args)

    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, run_config,
                                       model_constructor, flag_list)
    run_estimator(model_fn, input_fn)
Exemplo n.º 18
0
def main(_):
    """Run ``DualBertTwoInputModel`` for simple prediction on two-input data.

    ``"feed_features"`` is always appended to the special flags. Raw token-id
    entries are removed from the prediction output.

    Returns:
        The result of ``run_estimator``.
    """
    tf_logging.info("DualBertTwoInputModel simple prediction")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Drop the raw token ids from the prediction output. dict.pop with a
        # default never raises, so no try/except is needed for absent keys.
        for key in ('input_ids', 'input_ids2'):
            predictions.pop(key, None)
        return predictions

    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputModel, special_flags,
                                       override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Exemplo n.º 19
0
def run_classification_task(model_class):
    """Run a classification task for ``model_class`` using flag-driven config.

    Args:
        model_class: model class (or factory) passed to ``model_fn_classification``.

    Returns:
        The result of ``run_estimator``.
    """
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    flag_list = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(model_config, run_config, model_class,
                                       flag_list)
    return run_estimator(model_fn, input_fn)