Example #1
def input_fn_perspective_passage(flags):
    input_files = get_input_files_from_flags(flags)
    max_seq_length = flags.max_seq_length
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_ids2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "strict_good": tf.io.FixedLenFeature([1], tf.int64),
            "strict_bad": tf.io.FixedLenFeature([1], tf.int64),
        }
        return format_dataset(name_to_features, batch_size, is_training, flags,
                              input_files, num_cpu_threads)

    return input_fn
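The format_dataset helper used above is project code that is not shown in this listing. For orientation only, here is a minimal sketch of what such a helper typically does with a name_to_features spec like this one, assuming plain TFRecord inputs and the standard tf.data parsing pipeline; the function name and the unused flags argument are hypothetical.

import tensorflow as tf

def format_dataset_sketch(name_to_features, batch_size, is_training,
                          flags, input_files, num_cpu_threads):
    """Hypothetical stand-in for the project's format_dataset helper."""
    def decode(record):
        # Parse one serialized tf.Example into fixed-length int64 tensors.
        example = tf.io.parse_single_example(record, name_to_features)
        # Downcast to int32, which estimator/TPU pipelines commonly expect.
        return {k: tf.cast(v, tf.int32) for k, v in example.items()}

    dataset = tf.data.TFRecordDataset(input_files)
    if is_training:
        # Shuffle and repeat only for training; buffer size is a placeholder.
        dataset = dataset.shuffle(buffer_size=10000).repeat()
    dataset = dataset.map(decode, num_parallel_calls=num_cpu_threads)
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)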
Example #2
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_hinge,
        )
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)

    else:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_pred,
        )

        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
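run_estimator is the shared driver that every example in this listing hands its model_fn and input_fn to; its implementation is not part of this section. As a rough, assumption-labeled outline, it presumably dispatches on the do_train/do_eval/do_predict flags along the lines of the standard tf.estimator loop below; the flag names output_dir, save_checkpoints_steps, num_train_steps, and max_eval_steps are guesses.

def run_estimator_sketch(model_fn, input_fn):
    """Hypothetical outline of the shared run_estimator driver."""
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,                            # assumed flag
        save_checkpoints_steps=FLAGS.save_checkpoints_steps)   # assumed flag
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    if FLAGS.do_train:
        estimator.train(input_fn=input_fn, max_steps=FLAGS.num_train_steps)
    if FLAGS.do_eval:
        return estimator.evaluate(input_fn=input_fn, steps=FLAGS.max_eval_steps)
    if FLAGS.do_predict:
        return list(estimator.predict(input_fn=input_fn))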
Example #3
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    ssdr_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    model_fn = model_fn_apr_classification(
        bert_config=bert_config,
        ssdr_config=ssdr_config,
        train_config=train_config,
        dict_run_config=DictRunConfig.from_flags(FLAGS),
    )
    if FLAGS.do_train:
        input_fn = input_fn_builder(input_files=input_files,
                                    max_seq_length=FLAGS.max_seq_length,
                                    is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder(input_files=input_files,
                                    max_seq_length=FLAGS.max_seq_length,
                                    is_training=False)
    else:
        raise Exception("Either do_train, do_eval, or do_predict must be set")

    return run_estimator(model_fn, input_fn)
Example #4
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)

    show_input_files(input_files)

    model_fn = loss_diff_prediction_model(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    if FLAGS.do_train:
        input_fn = input_fn_builder_masked(
            input_files=input_files,
            flags=FLAGS,
            is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder_masked(
            input_files=input_files,
            flags=FLAGS,
            is_training=False)
    else:
        raise Exception("Either do_train, do_eval, or do_predict must be set")

    run_estimator(model_fn, input_fn)
Example #5
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        config,
        train_config,
        BertModel,
        special_flags
    )

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(input_files,
                                                   FLAGS.max_seq_length,
                                                   is_training,
                                                   FLAGS,
                                                   num_cpu_threads=4,
                                                   repeat_for_eval=False)
    else:
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)

    result = run_estimator(model_fn, input_fn)
    return result
Example #6
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_explain(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    r = run_estimator(model_fn, input_fn)
    return r
Example #7
def main(_):
    input_files = get_input_files_from_flags(FLAGS)

    show_input_files(input_files)

    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction(
            input_files=input_files, max_seq_length=FLAGS.max_seq_length)
    else:
        assert False

    result = run_estimator(model_fn, input_fn)
    return result
Example #8
def run_classification_w_second_input():
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_classification(config, train_config, ME5_2,
                                       special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    result = run_estimator(model_fn, input_fn)
    return result
Example #9
def run_classification_w_second_input():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification(
        bert_config,
        train_config,
    )
    input_fn = input_fn_builder_use_second_input(FLAGS)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Example #10
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_explain(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    is_training = FLAGS.do_train
    input_fn = input_fn_builder(input_files, FLAGS.max_seq_length, is_training)

    r = run_estimator(model_fn, input_fn)
    return r
Example #11
def run_w_data_id():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification_weighted_loss(
        bert_config,
        train_config,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #12
def main(_):
    tf_logging.info("Token scoring")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    input_fn = input_fn_token_scoring2(input_files, FLAGS, FLAGS.do_train)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    model_fn = model_fn_token_scoring(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Example #13
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_classification_with_ada(
        config,
        train_config,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_ada(input_files, FLAGS, FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #14
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_sensitivity(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=special_flags,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    result = run_estimator(model_fn, input_fn)
    return result
Example #15
def main(_):
    input_files = get_input_files_from_flags(FLAGS)

    show_input_files(input_files)

    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction_w_data_id(
            input_files=input_files, max_seq_length=FLAGS.max_seq_length)
    else:
        assert False

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    result = run_estimator(model_fn, input_fn)
    return result
Example #16
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=FreezeEmbedding,
        special_flags=special_flags,
    )

    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)

    result = run_estimator(model_fn, input_fn)
    return result
Example #17
def run_w_data_id():
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=special_flags,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=input_files, flags=FLAGS, is_training=FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Example #18
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(input_pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)

    show_input_files(input_files)

    model_fn = model_fn_preserved_dim(
        bert_config=bert_config,
        train_config=train_config,
    )
    if FLAGS.do_predict:
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=False)
    else:
        raise Exception("Only PREDICT mode is allowed")

    run_estimator(model_fn, input_fn)
Example #19
def input_fn_builder_multi_context_classification(max_seq_length, max_context, max_context_length, flags):
    input_files = get_input_files_from_flags(flags)
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4
    raw_context_len = max_context * max_context_length
    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]
        name_to_features = {
            "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "context_input_ids": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "context_input_mask": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "context_segment_ids": tf.io.FixedLenFeature([raw_context_len], tf.int64),
            "label_ids": tf.io.FixedLenFeature([1], tf.int64),
        }
        return format_dataset(name_to_features, batch_size, is_training, flags, input_files, num_cpu_threads)

    return input_fn
Example #20
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)

    show_input_files(input_files)

    model_fn = model_fn_tlm_debug(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
        model_class=BertModel,
    )
    if FLAGS.do_predict:
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=False)
    else:
        assert False

    r = run_estimator(model_fn, input_fn)
    return r
Example #21
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")

    def override_prediction_fn(predictions, model):
        predictions['vector'] = model.get_output()
        return predictions

    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=MultiEvidenceUseFirst,
        special_flags=special_flags,
        override_prediction_fn=override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    result = run_estimator(model_fn, input_fn)
    return result
Example #22
def main(_):
    tf_logging.addFilter(CounterFilter())
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)

    show_input_files(input_files)

    model_fn = model_fn_try_all_loss(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    if FLAGS.do_predict:
        input_fn = input_fn_builder_unmasked(
            input_files=input_files,
            flags=FLAGS,
            is_training=False)
    else:
        assert False

    r = run_estimator(model_fn, input_fn)
    return r
Example #23
def main(_):
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    is_training = FLAGS.do_train
    model_fn = model_fn_binary_classification_loss(
        config,
        train_config,
        MES_const_0_handle,
    )
    input_fn = input_fn_builder_classification(input_files,
                                               FLAGS.max_d_seq_length,
                                               is_training,
                                               FLAGS,
                                               num_cpu_threads=4,
                                               repeat_for_eval=False)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Example #24
def main(_):
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = get_input_files_from_flags(FLAGS)
    train_config = TrainConfigEx.from_flags(FLAGS)

    show_input_files(input_files)

    model_fn = loss_diff_prediction_model_online(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
    )
    if FLAGS.do_train:
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=False)
    else:
        raise Exception("Either do_train, do_eval, or do_predict must be set")

    run_estimator(model_fn, input_fn)
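All of the main(_) entry points above follow the absl/tf.app convention of reading configuration from module-level FLAGS. A typical launcher block is sketched below; the flag declarations are inferred from how FLAGS is accessed in the examples, and the defaults and help strings are placeholders.

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("input_file", None, "Comma-separated TFRecord patterns.")
flags.DEFINE_string("bert_config_file", None, "Path to the BERT config JSON.")
flags.DEFINE_string("model_config_file", None, "Path to the model config JSON.")
flags.DEFINE_string("special_flags", "", "Comma-separated special flags.")
flags.DEFINE_integer("max_seq_length", None, "Maximum sequence length.")
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run evaluation.")
flags.DEFINE_bool("do_predict", False, "Whether to run prediction.")
# ...plus the remaining flags referenced above (e.g. max_d_seq_length).

if __name__ == "__main__":
    app.run(main)  # absl parses the flags, then calls main(argv)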