def old_main():
    """Run loss-diff prediction over every matched input file (PREDICT only)."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # Expand the comma-separated glob patterns into concrete file paths.
    input_files = []
    for pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(pattern))

    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    tf_logging.addFilter(CounterFilter())

    # Build one input_fn / output-path pair per input file.
    input_fn_list = []
    output_names = []
    for input_file in input_files:
        file_name = input_file.split("/")[-1]
        input_fn_list.append(
            input_fn_builder_unmasked(input_files=[input_file],
                                      flags=FLAGS,
                                      is_training=False))
        output_names.append(
            "disk_output/loss_predictor_predictions/" + file_name)

    model_fn = loss_diff_predict_only_model_fn(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    # Guard clause instead of if/else: only PREDICT is supported.
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    run_estimator_loop(model_fn, input_fn_list, output_names)
def main(_):
    """Train/eval with MES_hinge pairwise loss; predict with MES_pred."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    # NOTE(review): special_flags is built but never passed below — looks
    # vestigial; confirm before removing.
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    is_training = FLAGS.do_train
    training_or_eval = FLAGS.do_train or FLAGS.do_eval
    if training_or_eval:
        model_fn = model_fn_with_loss(config, train_config, MES_hinge)
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)
    else:
        model_fn = model_fn_with_loss(config, train_config, MES_pred)
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)

    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Classification: plain input_fn for train/eval, data-id input_fn for predict."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(config, train_config, BertModel,
                                       special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())

    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(input_files,
                                                   FLAGS.max_seq_length,
                                                   is_training,
                                                   FLAGS,
                                                   num_cpu_threads=4,
                                                   repeat_for_eval=False)
    else:
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files,
            FLAGS.max_seq_length,
            FLAGS,
            is_training,
            num_cpu_threads=4)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run classification with a triple-BERT model selected by FLAGS.modeling.

    Supported values: "TripleBertMasking", "TripleBertWeighted".

    Raises:
        ValueError: if FLAGS.modeling names an unsupported model.
    """
    tf_logging.info("TripleBertMasking")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_triple(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Merge the model's extra prediction tensors into the output dict.
        for key, value in model.get_predictions().items():
            predictions[key] = value
        return predictions

    # Fix: was `assert False`, which is stripped under `python -O` and would
    # then crash later with NameError on model_class; raise explicitly.
    model_classes = {
        "TripleBertMasking": TripleBertMasking,
        "TripleBertWeighted": TripleBertWeighted,
    }
    try:
        model_class = model_classes[FLAGS.modeling]
    except KeyError:
        raise ValueError(
            "Unsupported modeling option: {}".format(FLAGS.modeling))

    model_fn = model_fn_classification(config, train_config, model_class,
                                       special_flags, override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run classification-with-confidence over two inputs with data ids."""
    tf_logging.info("Classification with confidence")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    trainer_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    model_fn = model_fn_classification_with_confidence(model_config,
                                                       trainer_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run classification with the alternative loss formulation."""
    tf_logging.info("Classification with alt loss")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    trainer_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    model_fn = model_fn_classification_with_alt_loss(model_config,
                                                     trainer_config,
                                                     BertModel)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run weighted regression training/prediction.

    Fix: corrected the log-message typo "weigth" -> "weight".
    """
    tf_logging.info("Regression with weight")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_regression(get_input_files_from_flags(FLAGS),
                                           FLAGS,
                                           FLAGS.do_train)
    model_fn = model_fn_regression(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run the manual-combiner classification model (two inputs + data id)."""
    tf_logging.info("Manual combiner")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    # Fix: removed the unused special_flags computation — it was built from
    # FLAGS.special_flags but never passed to the model_fn builder.
    model_fn = model_fn_classification_manual_combiner(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """QCK classification with the ME7 multi-evidence model."""
    tf_logging.info("QCK with ME7")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    # Split flags, then append the feature-feeding flag.
    special_flags = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(config, train_config, ME7,
                                       special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Vector-combining model for multi-evidence QK."""
    tf_logging.info("Multi-evidence for QK")
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_vector_ck(FLAGS, model_config)
    # No special flags for this model.
    model_fn = vector_combining_model(model_config, train_config, [])
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """QCK classification with the relevance-aware dual BERT model."""
    tf_logging.info("QCK with rel")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_rel(FLAGS)
    flag_list = FLAGS.special_flags.split(",")
    flag_list.append("feed_features")
    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputWRel, flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Dual-BERT model with doubled input length for the two inputs."""
    tf_logging.info("DualBertTwoInputWithDoubleInputLength")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_dual_bert_double_length_input(FLAGS)
    flag_list = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputWithDoubleInputLength,
                                       flag_list)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def run_classification_w_second_input():
    """ME5_2 classification over cppnc multi-evidence inputs."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    special_flags = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(config, train_config, ME5_2,
                                       special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    return run_estimator(model_fn, input_fn)
def run_classification_w_second_input():
    """Plain BERT classification using the second-input input_fn builder."""
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification(bert_config, train_config)
    input_fn = input_fn_builder_use_second_input(FLAGS)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def run_w_data_id():
    """Weighted-loss classification with data ids in the input."""
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification_weighted_loss(bert_config,
                                                     train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files,
        flags=FLAGS,
        is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Token-scoring model runner."""
    tf_logging.info("Token scoring")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    input_fn = input_fn_token_scoring2(input_files, FLAGS, FLAGS.do_train)
    # NOTE(review): special_flags is built but never passed to
    # model_fn_token_scoring — looks vestigial; confirm before removing.
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_token_scoring(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Classification-with-ADA runner."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    # NOTE(review): special_flags is built but never passed to the model_fn
    # builder — confirm whether it is still needed.
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_classification_with_ada(config, train_config)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_ada(input_files, FLAGS, FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def __init__(self, out_dir, input_gs_dir):
    """Build a TPUEstimator for loss-diff prediction, configured from FLAGS.

    Args:
        out_dir: output directory path, recorded on the instance.
        input_gs_dir: input directory path (presumably a GCS path given the
            name — TODO confirm), recorded on the instance.
    """
    self.output_dir = out_dir
    self.input_gs_dir = input_gs_dir
    bert_config = modeling.BertConfig.from_json_file(
        FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    tf_logging.addFilter(CounterFilter())
    # Predict-only model_fn: no training ops are built.
    model_fn = loss_diff_predict_only_model_fn(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    tf.io.gfile.makedirs(FLAGS.output_dir)
    # Resolve the TPU cluster only when both flags are set; otherwise the
    # estimator falls back to local execution.
    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        session_config=config,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))
    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    # Note: predict_batch_size deliberately reuses eval_batch_size.
    self.estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
    )
def main(_):
    """Sensitivity model runner over second-input data."""
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_sensitivity(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Predict rank scores with data ids (PREDICT mode only).

    Raises:
        Exception: if FLAGS.do_predict is not set.
    """
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    # Fix: was `assert False` in the else branch, which is stripped under
    # `python -O` (model_fn/input_fn would then be unbound). Raise explicitly,
    # matching the message used by the other predict-only entry points.
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    model_fn = model_fn_rank_pred(FLAGS)
    input_fn = input_fn_builder_prediction_w_data_id(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length)
    tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def run_w_data_id():
    """BERT classification with data ids (builder name typo kept as-is)."""
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=input_files,
        flags=FLAGS,
        is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def main(_):
    """DualBertTwoInputModel with a reduced learning rate for one var group."""
    tf_logging.info("DualBertTwoInputModel Two learning rate")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def in_scaled_group(name):
        # Variables under dual_model_prefix1 get the scaled learning rate.
        return dual_model_prefix1 in name

    scale = 1 / 50
    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputModel, special_flags,
                                       in_scaled_group, scale)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Multi-evidence dot-product (CK) pointwise ranking.

    Raises:
        ValueError: if FLAGS.modeling is neither None nor "pmp".
    """
    tf_logging.info("Multi-evidence with Dot Product (CK version)")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    total_doc_length = config.max_doc_length * config.num_docs
    input_fn = input_fn_builder_dot_product_ck(FLAGS, config.max_sent_length,
                                               total_doc_length)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    # Fix 1: local renamed from `modeling`, which shadowed the `modeling`
    # module used elsewhere in this file.
    # Fix 2: `assert False` replaced with an explicit raise — asserts are
    # stripped under `python -O`, leaving model_class unbound.
    modeling_option = FLAGS.modeling
    if modeling_option is None:
        model_class = BertModel
    elif modeling_option == "pmp":
        model_class = ProjectedMaxPooling
    else:
        raise ValueError(
            "Unsupported modeling option: {}".format(modeling_option))
    model_fn = model_fn_pointwise_ranking(config, train_config, model_class,
                                          special_flags)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Multi-evidence-use-first classification, exporting an output vector."""
    input_files = get_input_files_from_flags(FLAGS)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)

    def override_prediction_fn(predictions, model):
        # Attach the model's output tensor under the 'vector' key.
        predictions['vector'] = model.get_output()
        return predictions

    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=MultiEvidenceUseFirst,
        special_flags=FLAGS.special_flags.split(","),
        override_prediction_fn=override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
def run_estimator(model_fn, input_fn, host_call=None):
    """Build a plain (non-TPU) Estimator from FLAGS and run training.

    Args:
        model_fn: Estimator model function.
        input_fn: Estimator input function.
        host_call: unused here; kept for signature compatibility with the
            TPU variant of this runner.
    """
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    # No TPU resolution in this variant; cluster stays None.
    tpu_cluster_resolver = None
    config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
    # Fix: removed unused local `is_per_host` (no tpu_config is built here)
    # and deleted commented-out checkpoint-resolution code.
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config,
        tf_random_seed=FLAGS.random_seed,
    )
    if FLAGS.random_seed is not None:
        tf_logging.info("Using random seed : {}".format(FLAGS.random_seed))
        tf.random.set_seed(FLAGS.random_seed)
    # Plain Estimator (not TPUEstimator): runs on CPU/GPU.
    # TODO(review): batch_size is hard-coded to 16 — consider wiring it to
    # FLAGS.train_batch_size.
    estimator = tf.compat.v1.estimator.Estimator(model_fn=model_fn,
                                                 config=run_config,
                                                 params={'batch_size': 16})
    if FLAGS.do_train:
        tf_logging.info("***** Running training *****")
        tf_logging.info(" Batch size = %d", FLAGS.train_batch_size)
        estimator.train(input_fn=input_fn, max_steps=FLAGS.num_train_steps)
def main(_):
    """DualBertTwoInputModel prediction with bulky input tensors dropped."""
    tf_logging.info("DualBertTwoInputModel simple prediction")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_fn = input_fn_builder_two_inputs_w_data_id(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")

    def override_prediction_fn(predictions, model):
        # Drop input tensors from the prediction output to keep it small.
        # Fix: dict.pop(key, default) never raises KeyError, so the previous
        # try/except around the second pop was dead code.
        predictions.pop('input_ids', None)
        predictions.pop('input_ids2', None)
        return predictions

    model_fn = model_fn_classification(config, train_config,
                                       DualBertTwoInputModel, special_flags,
                                       override_prediction_fn)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Predict with the try-all-loss model (PREDICT mode only).

    Raises:
        Exception: if FLAGS.do_predict is not set.
    """
    tf_logging.addFilter(CounterFilter())
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_try_all_loss(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    # Fix: was `assert False` in the else branch, which is stripped under
    # `python -O` (input_fn would then be unbound); raise explicitly, matching
    # the message used by the other predict-only entry points.
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    input_fn = input_fn_builder_unmasked(
        input_files=input_files,
        flags=FLAGS,
        is_training=False)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Binary classification with MES_const_0_handle over doc-length inputs."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    # NOTE(review): special_flags is built but never passed on — looks
    # vestigial; confirm before removing.
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_binary_classification_loss(config, train_config,
                                                   MES_const_0_handle)
    input_fn = input_fn_builder_classification(input_files,
                                               FLAGS.max_d_seq_length,
                                               FLAGS.do_train,
                                               FLAGS,
                                               num_cpu_threads=4,
                                               repeat_for_eval=False)
    if FLAGS.do_predict:
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def run_estimator(model_fn, input_fn, host_call=None):
    """Build a TPUEstimator from FLAGS and run train / eval / predict.

    Args:
        model_fn: Estimator model function.
        input_fn: Estimator input function (reused for every enabled mode).
        host_call: unused in the body — kept for caller compatibility;
            TODO(review) confirm it can be dropped.

    Returns:
        The eval metrics dict when FLAGS.do_eval is set; otherwise None.
        Predictions, when FLAGS.do_predict is set, are pickled to
        FLAGS.out_file.
    """
    tf_logging.setLevel(logging.INFO)
    if FLAGS.log_debug:
        tf_logging.setLevel(logging.DEBUG)
    #FLAGS.init_checkpoint = auto_resolve_init_checkpoint(FLAGS.init_checkpoint)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    if FLAGS.do_predict:
        tf_logging.addFilter(CounterFilter())
    # Resolve the TPU cluster only when both flags are set.
    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    print("FLAGS.save_checkpoints_steps", FLAGS.save_checkpoints_steps)
    config = tf.compat.v1.ConfigProto(allow_soft_placement=False, )
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config,
        tf_random_seed=FLAGS.random_seed,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))
    if FLAGS.random_seed is not None:
        tf_logging.info("Using random seed : {}".format(FLAGS.random_seed))
        tf.random.set_seed(FLAGS.random_seed)
    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
    )
    if FLAGS.do_train:
        tf_logging.info("***** Running training *****")
        tf_logging.info(" Batch size = %d", FLAGS.train_batch_size)
        estimator.train(input_fn=input_fn, max_steps=FLAGS.num_train_steps)
    if FLAGS.do_eval:
        tf_logging.info("***** Running evaluation *****")
        tf_logging.info(" Batch size = %d", FLAGS.eval_batch_size)
        # Evaluate from an explicit checkpoint only when requested; otherwise
        # the latest checkpoint in model_dir is used.
        if FLAGS.initialize_to_predict:
            checkpoint = FLAGS.init_checkpoint
        else:
            checkpoint = None
        result = estimator.evaluate(input_fn=input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    checkpoint_path=checkpoint)
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf_logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf_logging.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        return result
    if FLAGS.do_predict:
        tf_logging.info("***** Running prediction *****")
        tf_logging.info(" Batch size = %d", FLAGS.eval_batch_size)
        if not FLAGS.initialize_to_predict:
            # Predicting from model_dir: make sure a checkpoint exists first.
            verify_checkpoint(estimator.model_dir)
            checkpoint = None
            time.sleep(1)
        else:
            checkpoint = FLAGS.init_checkpoint
        result = estimator.predict(input_fn=input_fn,
                                   checkpoint_path=checkpoint,
                                   yield_single_examples=False)
        # Fix: the output file handle was previously opened without being
        # closed (pickle.dump(..., open(...))); use a with-block.
        with open(FLAGS.out_file, "wb") as out_f:
            pickle.dump(list(result), out_f)
        tf_logging.info("Prediction saved at {}".format(FLAGS.out_file))