def input_fn_perspective_passage(flags):
    """Build an Estimator input_fn for paired perspective/passage examples.

    Each TFRecord example carries two encoded sequences (feature suffixes
    1 and 2) plus two scalar int labels, "strict_good" and "strict_bad".

    Args:
        flags: parsed flags object; must provide max_seq_length, do_train,
            and whatever get_input_files_from_flags/format_dataset read.

    Returns:
        A callable input_fn(params) where params["batch_size"] selects the
        batch size, returning the dataset built by format_dataset.
    """
    input_files = get_input_files_from_flags(flags)
    max_seq_length = flags.max_seq_length
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]
        # The original wrapped the literal in a redundant dict() call and
        # inserted the two label features afterwards; a single literal is
        # equivalent and clearer.
        name_to_features = {
            "input_ids1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids1": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_ids2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids2": tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "strict_good": tf.io.FixedLenFeature([1], tf.int64),
            "strict_bad": tf.io.FixedLenFeature([1], tf.int64),
        }
        return format_dataset(name_to_features, batch_size, is_training,
                              flags, input_files, num_cpu_threads)

    return input_fn
def main(_):
    """Train/eval with the MES hinge loss, or predict with MES_pred."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        # Pairwise hinge-loss setup for training/evaluation.
        model_fn = model_fn_with_loss(config, train_config, MES_hinge)
        input_fn = input_fn_builder_pairwise(FLAGS.max_d_seq_length, FLAGS)
    else:
        # Pointwise prediction setup with data-id tracking.
        model_fn = model_fn_with_loss(config, train_config, MES_pred)
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files, FLAGS.max_seq_length, FLAGS, is_training,
            num_cpu_threads=4)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Train/eval/predict the APR classification model."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(pattern))
    show_input_files(input_files)
    train_config = TrainConfigEx.from_flags(FLAGS)
    ssdr_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    model_fn = model_fn_apr_classification(
        bert_config=bert_config,
        ssdr_config=ssdr_config,
        train_config=train_config,
        dict_run_config=DictRunConfig.from_flags(FLAGS),
    )
    # Decide the pipeline mode first; fail fast when no mode flag is set.
    if FLAGS.do_train:
        is_training = True
    elif FLAGS.do_eval or FLAGS.do_predict:
        is_training = False
    else:
        raise Exception()
    input_fn = input_fn_builder(input_files=input_files,
                                max_seq_length=FLAGS.max_seq_length,
                                is_training=is_training)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Train/eval/predict the loss-difference prediction model."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    show_input_files(input_files)
    model_fn = loss_diff_prediction_model(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        model_config=model_config,
    )
    # Decide the pipeline mode first; fail fast when no mode flag is set.
    if FLAGS.do_train:
        is_training = True
    elif FLAGS.do_eval or FLAGS.do_predict:
        is_training = False
    else:
        raise Exception()
    input_fn = input_fn_builder_masked(input_files=input_files,
                                       flags=FLAGS,
                                       is_training=is_training)
    run_estimator(model_fn, input_fn)
def main(_):
    """Run classification; prediction uses the data-id input pipeline."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_classification(config, train_config, BertModel,
                                       special_flags)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    is_training = FLAGS.do_train
    if FLAGS.do_train or FLAGS.do_eval:
        input_fn = input_fn_builder_classification(
            input_files, FLAGS.max_seq_length, is_training, FLAGS,
            num_cpu_threads=4, repeat_for_eval=False)
    else:
        input_fn = input_fn_builder_classification_w_data_id2(
            input_files, FLAGS.max_seq_length, FLAGS, is_training,
            num_cpu_threads=4)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run the explain model over a flags-driven classification pipeline."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    show_input_files(input_files)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_explain(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    input_fn = input_fn_from_flags(input_fn_builder_classification, FLAGS)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Prediction-only entry point for the rank prediction model."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    # Only PREDICT mode is supported by this entry point.
    assert FLAGS.do_predict
    model_fn = model_fn_rank_pred(FLAGS)
    input_fn = input_fn_builder_prediction(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length)
    return run_estimator(model_fn, input_fn)
def run_classification_w_second_input():
    """Run ME5_2 classification over the multi-evidence (cppnc) pipeline."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",") + ["feed_features"]
    model_fn = model_fn_classification(config, train_config, ME5_2,
                                       special_flags)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_cppnc_multi_evidence(FLAGS)
    return run_estimator(model_fn, input_fn)
def run_classification_w_second_input():
    """Run BERT classification over the second-input pipeline."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification(bert_config, train_config)
    input_fn = input_fn_builder_use_second_input(FLAGS)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run the explain model with the plain input_fn_builder pipeline."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    show_input_files(input_files)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_explain(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    input_fn = input_fn_builder(input_files, FLAGS.max_seq_length,
                                FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def run_w_data_id():
    """Weighted-loss classification with data-id tracking."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification_weighted_loss(bert_config,
                                                     train_config)
    if FLAGS.do_predict:
        # Keep prediction logs readable.
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files,
        flags=FLAGS,
        is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Token-scoring entry point: build model_fn/input_fn and run.

    Reads model/train configuration from flags, builds the token-scoring
    model and its input pipeline, and delegates to run_estimator.
    """
    tf_logging.info("Token scoring")
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    input_fn = input_fn_token_scoring2(input_files, FLAGS, FLAGS.do_train)
    # NOTE: the original also built a special_flags list (split + append of
    # "feed_features") here but never passed it anywhere; removed as dead
    # code.
    model_fn = model_fn_token_scoring(config, train_config)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run ADA classification: build model_fn/input_fn and run estimator."""
    input_files = get_input_files_from_flags(FLAGS)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    # NOTE: the original built a special_flags list (split + append of
    # "feed_features") here but never used it; removed as dead code.
    model_fn = model_fn_classification_with_ada(config, train_config)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_ada(input_files, FLAGS, FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Run the sensitivity model over the second-input pipeline."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    model_fn = model_fn_sensitivity(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=special_flags,
    )
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Prediction-only entry for rank prediction with data ids.

    Raises:
        AssertionError: when --do_predict is not set.
    """
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    if FLAGS.do_predict:
        model_fn = model_fn_rank_pred(FLAGS)
        input_fn = input_fn_builder_prediction_w_data_id(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length)
        # Silence noisy enqueue logging during prediction. The original
        # re-tested FLAGS.do_predict here, but the else branch above is
        # `assert False`, so that second check was always true when reached.
        tf_logging.addFilter(MuteEnqueueFilter())
    else:
        assert False
    result = run_estimator(model_fn, input_fn)
    return result
def main(_):
    """Classification with frozen embeddings over data-id inputs."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=FreezeEmbedding,
        special_flags=FLAGS.special_flags.split(","),
    )
    input_fn = input_fn_builder_classification_w_data_id(
        input_files=input_files,
        flags=FLAGS,
        is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def run_w_data_id():
    """BERT classification with the data-ids (typo variant) pipeline."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
        special_flags=FLAGS.special_flags.split(","),
    )
    if FLAGS.do_predict:
        # Keep prediction logs readable.
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_classification_w_data_ids_typo(
        input_files=input_files,
        flags=FLAGS,
        is_training=FLAGS.do_train)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Predict-only entry for the preserved-dim model."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = []
    for pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.io.gfile.glob(pattern))
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = model_fn_preserved_dim(
        bert_config=bert_config,
        train_config=train_config,
    )
    # Guard clause: this entry point refuses any mode other than PREDICT.
    if not FLAGS.do_predict:
        raise Exception("Only PREDICT mode is allowed")
    input_fn = input_fn_builder_unmasked(input_files=input_files,
                                         flags=FLAGS,
                                         is_training=False)
    run_estimator(model_fn, input_fn)
def input_fn_builder_multi_context_classification(max_seq_length, max_context, max_context_length, flags):
    """Build an input_fn for classification with flattened multi-context features.

    The context tensors are stored flattened as max_context blocks of
    max_context_length tokens each.
    """
    input_files = get_input_files_from_flags(flags)
    show_input_files(input_files)
    is_training = flags.do_train
    num_cpu_threads = 4
    raw_context_len = max_context * max_context_length

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        def int_feature(length):
            # All features in this record format are fixed-length int64.
            return tf.io.FixedLenFeature([length], tf.int64)

        name_to_features = {
            "input_ids": int_feature(max_seq_length),
            "input_mask": int_feature(max_seq_length),
            "segment_ids": int_feature(max_seq_length),
            "context_input_ids": int_feature(raw_context_len),
            "context_input_mask": int_feature(raw_context_len),
            "context_segment_ids": int_feature(raw_context_len),
            "label_ids": int_feature(1),
        }
        return format_dataset(name_to_features, batch_size, is_training,
                              flags, input_files, num_cpu_threads)

    return input_fn
def main(_):
    """Predict-only entry for the TLM debug model."""
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    show_input_files(input_files)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_tlm_debug(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
        model_class=BertModel,
    )
    # Only PREDICT mode is supported by this entry point.
    assert FLAGS.do_predict
    input_fn = input_fn_builder_unmasked(input_files=input_files,
                                         flags=FLAGS,
                                         is_training=False)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Multi-evidence classification that also exposes the output vector."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",")

    def override_prediction_fn(predictions, model):
        # Attach the raw model output so downstream consumers can read it.
        predictions['vector'] = model.get_output()
        return predictions

    model_fn = model_fn_classification(
        bert_config=bert_config,
        train_config=train_config,
        model_class=MultiEvidenceUseFirst,
        special_flags=special_flags,
        override_prediction_fn=override_prediction_fn)
    if FLAGS.do_predict:
        # Keep prediction logs readable.
        tf_logging.addFilter(CounterFilter())
    input_fn = input_fn_builder_use_second_input(FLAGS)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Predict-only entry for the try-all-loss model."""
    tf_logging.addFilter(CounterFilter())
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = flags_wrapper.get_input_files()
    show_input_files(input_files)
    train_config = TrainConfigEx.from_flags(FLAGS)
    model_fn = model_fn_try_all_loss(
        bert_config=bert_config,
        train_config=train_config,
        logging=tf_logging,
    )
    # Only PREDICT mode is supported by this entry point.
    assert FLAGS.do_predict
    input_fn = input_fn_builder_unmasked(
        input_files=input_files,
        flags=FLAGS,
        is_training=False)
    return run_estimator(model_fn, input_fn)
def main(_):
    """Binary classification with the MES const-0 handle model."""
    input_files = get_input_files_from_flags(FLAGS)
    show_input_files(input_files)
    config = JsonConfig.from_json_file(FLAGS.model_config_file)
    train_config = TrainConfigEx.from_flags(FLAGS)
    special_flags = FLAGS.special_flags.split(",")
    special_flags.append("feed_features")
    model_fn = model_fn_binary_classification_loss(
        config, train_config, MES_const_0_handle)
    input_fn = input_fn_builder_classification(
        input_files, FLAGS.max_d_seq_length, FLAGS.do_train, FLAGS,
        num_cpu_threads=4, repeat_for_eval=False)
    if FLAGS.do_predict:
        # Silence noisy enqueue logging during prediction.
        tf_logging.addFilter(MuteEnqueueFilter())
    return run_estimator(model_fn, input_fn)
def main(_):
    """Train/eval/predict the online loss-difference prediction model.

    Raises:
        Exception: when none of do_train/do_eval/do_predict is set.
    """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_files = get_input_files_from_flags(FLAGS)
    train_config = TrainConfigEx.from_flags(FLAGS)
    show_input_files(input_files)
    model_fn = loss_diff_prediction_model_online(
        bert_config=bert_config,
        train_config=train_config,
        model_class=BertModel,
    )
    if FLAGS.do_train:
        # BUG FIX: the original passed is_training=False here, making the
        # training branch identical to the eval/predict branch. The sibling
        # loss_diff_prediction_model entry point uses is_training=True for
        # training, so this branch now does too.
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=True)
    elif FLAGS.do_eval or FLAGS.do_predict:
        input_fn = input_fn_builder_unmasked(input_files=input_files,
                                             flags=FLAGS,
                                             is_training=False)
    else:
        raise Exception()
    run_estimator(model_fn, input_fn)