Ejemplo n.º 1
0
def old_main():
    """Run loss-diff prediction separately over each input file (PREDICT only)."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    tf_logging.addFilter(CounterFilter())

    input_files = []
    for pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(pattern))

    input_fn_list = []
    output_names = []
    for path in input_files:
        # One estimator input_fn per file so each file gets its own output name.
        input_fn_list.append(
            input_fn_builder_unmasked(input_files=[path],
                                      flags=FLAGS,
                                      is_training=False))
        file_name = path.split("/")[-1]
        output_names.append("disk_output/loss_predictor_predictions/" + file_name)

    model_fn = loss_diff_predict_only_model_fn(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    run_estimator_loop(model_fn, input_fn_list, output_names)
Ejemplo n.º 2
0
def main(_):
    """Train/eval with pairwise hinge loss, or predict with MES_pred scoring.

    Removed the unused `special_flags` computation: it was built and appended
    to, but never passed to model_fn_with_loss.
    """
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_hinge,
        )
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)
    else:
        model_fn = model_fn_with_loss(
            config,
            train_config,
            MES_pred,
        )
        # Prediction path keeps data ids so outputs can be joined back to inputs.
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 3
0
def main(_):
    """Train/eval or predict a BERT classifier, choosing the input_fn per mode."""
    input_files = get_input_files_from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    flag_list = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(
        model_config,
        train_config,
        BertModel,
        flag_list
    )

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(
            input_files, FLAGS.max_seq_length, FLAGS.do_train, FLAGS,
            num_cpu_threads=4,
            repeat_for_eval=False)
    else:
        # Prediction keeps data ids so outputs can be matched back to inputs.
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            FLAGS.do_train,
            num_cpu_threads=4)

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 4
0
def main(_):
    """Run triple-BERT classification; model class is selected by FLAGS.modeling."""
    tf_logging.info("TripleBertMasking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_triple(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Merge every model-provided prediction tensor into the output dict.
        for key, value in model.get_predictions().items():
            predictions[key] = value
        return predictions

    if FLAGS.modeling == "TripleBertMasking":
        model_class = TripleBertMasking
    elif FLAGS.modeling == "TripleBertWeighted":
        model_class = TripleBertWeighted
    else:
        # Raise instead of `assert False`: asserts are stripped under -O and
        # give no diagnostic about the bad flag value.
        raise ValueError("Unknown modeling option: {}".format(FLAGS.modeling))

    model_fn = model_fn_classification(config, train_config, model_class,
                                       special_flags, override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 5
0
def main(_):
    """Run a classification model that also emits confidence scores."""
    tf_logging.info("Classification with confidence")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification_with_confidence(model_config, train_config)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 6
0
def main(_):
    """Run classification with the alternative loss."""
    tf_logging.info("Classification with alt loss")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_classification_with_alt_loss(model_config, train_config, BertModel)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 7
0
def main(_):
    """Run regression with sample weights.

    Fixes the "weigth" typo in the startup log message.
    """
    tf_logging.info("Regression with weight")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_regression(get_input_files_from_flags(FLAGS),
                                           FLAGS, FLAGS.do_train)

    model_fn = model_fn_regression(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 8
0
def main(_):
    """Run the manual-combiner classification model.

    Removed the unused `special_flags` computation: it was built but never
    passed to model_fn_classification_manual_combiner.
    """
    tf_logging.info("Manual combiner")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    model_fn = model_fn_classification_manual_combiner(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 9
0
def main(_):
    """Run the ME7 multi-evidence classifier (QCK)."""
    tf_logging.info("QCK with ME7")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    # "feed_features" is always enabled on top of the user-provided flags.
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, train_config, ME7, flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 10
0
def main(_):
    """Run the vector-combining model for multi-evidence QK."""
    tf_logging.info("Multi-evidence for QK")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_vector_ck(FLAGS, model_config)
    # This model takes no special flags.
    model_fn = vector_combining_model(model_config, train_config, [])
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 11
0
def main(_):
    """Run QCK classification with relevance inputs (DualBertTwoInputWRel)."""
    tf_logging.info("QCK with rel")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_rel(FLAGS)
    # "feed_features" is always enabled on top of the user-provided flags.
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, train_config,
                                       DualBertTwoInputWRel, flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 12
0
def main(_):
    """Run DualBertTwoInputWithDoubleInputLength classification."""
    tf_logging.info("DualBertTwoInputWithDoubleInputLength")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_dual_bert_double_length_input(FLAGS)
    # "feed_features" is always enabled on top of the user-provided flags.
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, train_config,
                                       DualBertTwoInputWithDoubleInputLength,
                                       flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 13
0
def run_classification_w_second_input():
    """Run ME5_2 classification on multi-evidence (cppnc) inputs."""
    show_input_files(get_input_files_from_flags(FLAGS))
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    # "feed_features" is always enabled on top of the user-provided flags.
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(model_config, train_config, ME5_2, flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 14
0
def run_classification_w_second_input():
    """Run plain BERT classification over the second-input pipeline."""
    show_input_files(get_input_files_from_flags(FLAGS))
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification(
        bert_config,
        train_config,
    )
    input_fn = input_fn_builder_use_second_input(FLAGS)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 15
0
def run_w_data_id():
    """Run weighted-loss classification, keeping data ids in the pipeline."""
    file_list = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(file_list)
    model_fn = model_fn_classification_weighted_loss(
        bert_config,
        train_config,
    )
    if FLAGS.do_predict:
        # Counter-style log filtering for the (potentially long) predict loop.
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=file_list, flags=FLAGS, is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 16
0
def main(_):
    """Run the token-scoring model.

    Removed the unused `special_flags` computation: it was built but never
    passed to model_fn_token_scoring.
    """
    tf_logging.info("Token scoring")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    input_fn = input_fn_token_scoring2(input_files, FLAGS, FLAGS.do_train)

    model_fn = model_fn_token_scoring(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 17
0
def main(_):
    """Run ADA-based classification.

    Removed the unused `special_flags` computation: it was built and appended
    to, but never passed to model_fn_classification_with_ada.
    """
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification_with_ada(
        config,
        train_config,
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_ada(input_files, FLAGS, FLAGS.do_train)
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 18
0
    def __init__(self, out_dir, input_gs_dir):
        """Build a TPUEstimator configured for loss-diff prediction.

        Args:
            out_dir: output directory recorded on this object.
            input_gs_dir: input directory (presumably a GCS path — confirm
                against callers) recorded on this object.
        """
        self.output_dir = out_dir
        self.input_gs_dir = input_gs_dir
        bert_config = modeling.BertConfig.from_json_file(
            FLAGS.bert_config_file)
        train_config = TrainConfigEx.from_flags(FLAGS)
        model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
        # Reduce repetitive log output.
        tf_logging.addFilter(CounterFilter())

        # Predict-only model function for loss-diff prediction.
        model_fn = loss_diff_predict_only_model_fn(
            bert_config=bert_config,
            train_config=train_config,
            model_class=BertModel,
            model_config=model_config,
        )

        tf.io.gfile.makedirs(FLAGS.output_dir)

        # Resolve the TPU cluster only when TPU use is requested and named.
        tpu_cluster_resolver = None
        if FLAGS.use_tpu and FLAGS.tpu_name:
            tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
        is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
        run_config = tf.compat.v1.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            master=FLAGS.master,
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
            session_config=config,
            tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop,
                num_shards=FLAGS.num_tpu_cores,
                per_host_input_for_training=is_per_host))

        # If TPU is not available, this will fall back to normal Estimator on CPU
        # or GPU.
        self.estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=model_fn,
            config=run_config,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            predict_batch_size=FLAGS.eval_batch_size,
        )
Ejemplo n.º 19
0
def main(_):
    """Run the sensitivity model over second-input data."""
    file_list = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(file_list)
    model_fn = model_fn_sensitivity(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Prediction-only rank-prediction entry point."""
    input_files = get_input_files_from_flags(FLAGS)

    show_input_files(input_files)

    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction_w_data_id(
            input_files=input_files, max_seq_length=FLAGS.max_seq_length)
    else:
        # Raise instead of `assert False`: asserts are stripped under -O and
        # give no diagnostic.
        raise Exception("Only PREDICT mode is supported")

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 21
0
def run_w_data_id():
    """Run BERT classification with data ids preserved in the pipeline."""
    file_list = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(file_list)
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    # NOTE: the builder name contains "typo" -- that is the project-provided name.
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=file_list, flags=FLAGS, is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 22
0
def main(_):
    """DualBertTwoInputModel with a reduced learning rate for one variable group."""
    tf_logging.info("DualBertTwoInputModel Two learning rate")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    # "feed_features" is always enabled on top of the user-provided flags.
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]

    def is_group_b(var_name):
        # Variables under the first dual-model prefix get the reduced rate.
        return dual_model_prefix1 in var_name

    model_fn = model_fn_classification(model_config, train_config,
                                       DualBertTwoInputModel, flag_list,
                                       is_group_b, 1 / 50)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 23
0
def main(_):
    """Multi-evidence dot-product (CK) pointwise ranking.

    The local previously named `modeling` shadowed the `modeling` module used
    elsewhere in this file; it is renamed to avoid confusion.
    """
    tf_logging.info("Multi-evidence with Dot Product (CK version)")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    total_doc_length = config.max_doc_length * config.num_docs
    input_fn = input_fn_builder_dot_product_ck(FLAGS, config.max_sent_length, total_doc_length)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    modeling_option = FLAGS.modeling
    if modeling_option is None:
        model_class = BertModel
    elif modeling_option == "pmp":
        model_class = ProjectedMaxPooling
    else:
        # Raise instead of `assert False`: asserts are stripped under -O.
        raise ValueError("Unknown modeling option: {}".format(modeling_option))

    model_fn = model_fn_pointwise_ranking(config, train_config, model_class, special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 24
0
def main(_):
    """MultiEvidenceUseFirst classification; predictions expose the model output vector."""
    file_list = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(file_list)

    def override_prediction_fn(predictions, model):
        # Attach the model's raw output vector to each prediction record.
        predictions['vector'] = model.get_output()
        return predictions

    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=MultiEvidenceUseFirst,
        special_flags=FLAGS.special_flags.split(","),
        override_prediction_fn=override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
Ejemplo n.º 25
0
def run_estimator(model_fn, input_fn, host_call=None):
    """Build a CPU/GPU Estimator from FLAGS and run training.

    Only the training path is implemented in this variant. `host_call` is
    accepted for interface compatibility but unused.

    Changes: removed the unused `is_per_host` local (no tpu_config is built
    here) and a commented-out checkpoint-resolution line.
    """
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())

    tpu_cluster_resolver = None
    config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config,
        tf_random_seed=FLAGS.random_seed,
    )

    if FLAGS.random_seed is not None:
        tf_logging.info("Using random seed : {}".format(FLAGS.random_seed))
        tf.random.set_seed(FLAGS.random_seed)

    # NOTE(review): batch size is hard-coded to 16 rather than taken from
    # FLAGS.train_batch_size -- confirm this is intentional.
    estimator = tf.compat.v1.estimator.Estimator(model_fn=model_fn,
                                                 config=run_config,
                                                 params={'batch_size': 16})

    if FLAGS.do_train:
        tf_logging.info("***** Running training *****")
        tf_logging.info("  Batch size = %d", FLAGS.train_batch_size)

        estimator.train(input_fn=input_fn, max_steps=FLAGS.num_train_steps)
Ejemplo n.º 26
0
def main(_):
    """DualBertTwoInputModel prediction without echoing input token ids."""
    tf_logging.info("DualBertTwoInputModel simple prediction")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Drop bulky token-id tensors from the prediction output.
        # dict.pop with a default never raises KeyError, so the try/except
        # that previously wrapped the second pop was dead code and is removed.
        predictions.pop('input_ids', None)
        predictions.pop('input_ids2', None)
        return predictions

    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputModel, special_flags,
                                       override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 27
0
def main(_):
    """Predict-only run of the try-all-loss model over unmasked inputs."""
    tf_logging.addFilter(CounterFilter())
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)

    show_input_files(input_files)

    model_fn = model_fn_try_all_loss(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    # This entry point supports prediction only.
    assert FLAGS.do_predict
    input_fn = input_fn_builder_unmasked(
        input_files=input_files,
        flags=FLAGS,
        is_training=False)

    return run_estimator(model_fn, input_fn)
Ejemplo n.º 28
0
def main(_):
    """Binary classification with the MES_const_0_handle model.

    Removed the unused `special_flags` computation: it was built and appended
    to, but never passed to model_fn_binary_classification_loss.
    """
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    is_training = FLAGS.do_train
    model_fn = model_fn_binary_classification_loss(
        config,
        train_config,
        MES_const_0_handle,
    )
    input_fn = input_fn_builder_classification(input_files,
                                               FLAGS.max_d_seq_length,
                                               is_training,
                                               FLAGS,
                                               num_cpu_threads=4,
                                               repeat_for_eval=False)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    result = run_estimator(model_fn, input_fn)
    return result
Ejemplo n.º 29
0
def run_estimator(model_fn, input_fn, host_call=None):
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)
    #FLAGS.init_checkpoint = auto_resolve_init_checkpoint(FLAGS.init_checkpoint)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    print("FLAGS.save_checkpoints_steps", FLAGS.save_checkpoints_steps)
    config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config,
        tf_random_seed=FLAGS.random_seed,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    if FLAGS.random_seed is not None:
        tf_logging.info("Using random seed : {}".format(FLAGS.random_seed))
        tf.random.set_seed(FLAGS.random_seed)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
    )

    if FLAGS.do_train:
        tf_logging.info("***** Running training *****")
        tf_logging.info("  Batch size = %d", FLAGS.train_batch_size)

        estimator.train(input_fn=input_fn, max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf_logging.info("***** Running evaluation *****")
        tf_logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        if FLAGS.initialize_to_predict:
            checkpoint = FLAGS.init_checkpoint
        else:
            checkpoint = None
        result = estimator.evaluate(input_fn=input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    checkpoint_path=checkpoint)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf_logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf_logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        return result
    if FLAGS.do_predict:
        tf_logging.info("***** Running prediction *****")
        tf_logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        if not FLAGS.initialize_to_predict:
            verify_checkpoint(estimator.model_dir)
            checkpoint = None
            time.sleep(1)
        else:
            checkpoint = FLAGS.init_checkpoint

        result = estimator.predict(input_fn=input_fn,
                                   checkpoint_path=checkpoint,
                                   yield_single_examples=False)
        pickle.dump(list(result), open(FLAGS.out_file, "wb"))
        tf_logging.info("Prediction saved at {}".format(FLAGS.out_file))