def main(_):
    """Entry point for the "horizon" language-model training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the unmasked input function and
    the LM model function (``BertologyFactory(HorizontalAlpha)`` with the
    ALBERT masked-LM output head), and hands both to ``run_estimator``.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        Whatever ``run_estimator`` returns.
    """
    set_level_debug()
    tf_logging.info("Train horizon")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_builder_unmasked(
        get_input_files_from_flags(FLAGS),
        FLAGS,
        training_mode,
    )
    build_model = model_fn_lm(
        model_config,
        run_settings,
        BertologyFactory(HorizontalAlpha),
        get_masked_lm_output_albert,
    )
    return run_estimator(build_model, build_input)
def main(_):
    """Entry point for the ALBERT language-model training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the unmasked input function and
    the LM model function (``Albert.factory`` with the ALBERT masked-LM
    output head), and hands both to ``run_estimator``.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        Whatever ``run_estimator`` returns.
    """
    tf_logging.info("Train albert")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_builder_unmasked(
        get_input_files_from_flags(FLAGS),
        FLAGS,
        training_mode,
    )
    build_model = model_fn_lm(
        model_config,
        run_settings,
        Albert.factory,
        get_masked_lm_output_albert,
    )
    return run_estimator(build_model, build_input)
def main(_):
    """Entry point for the topic-vector language-model training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the topic input function and the
    LM model function (``TopicVectorBert.factory`` with the standard
    masked-LM output head), and hands both to ``run_estimator``.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        Whatever ``run_estimator`` returns.
    """
    tf_logging.info("Train topic_vector")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_topic_fn(
        get_input_files_from_flags(FLAGS),
        FLAGS,
        training_mode,
    )
    # The trailing positional True is forwarded to model_fn_lm unchanged.
    build_model = model_fn_lm(
        model_config,
        run_settings,
        TopicVectorBert.factory,
        get_masked_lm_output,
        True,
    )
    return run_estimator(build_model, build_input)
def main(_):
    """Entry point for the reshape-BERT language-model training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the unmasked input function and
    the LM model function for ``ReshapeBertModel``, and hands both to
    ``run_estimator``.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        Whatever ``run_estimator`` returns.
    """
    set_level_debug()
    tf_logging.info("Train reshape bert")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_builder_unmasked(
        get_input_files_from_flags(FLAGS),
        FLAGS,
        training_mode,
    )
    build_model = model_fn_lm(model_config, run_settings, ReshapeBertModel)
    return run_estimator(build_model, build_input)
def main(_):
    """Entry point for the ``BertModelWithLabel`` training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the classification input function
    (with ``repeat_for_eval=True``) and the LM model function for
    ``BertModelWithLabel`` (features fed to the model via
    ``feed_feature=True``), and hands both to ``run_estimator``.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        Whatever ``run_estimator`` returns.
    """
    tf_logging.info("Train BertModelWithLabel")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_builder_classification(
        get_input_files_from_flags(FLAGS),
        FLAGS.max_seq_length,
        training_mode,
        FLAGS,
        repeat_for_eval=True,
    )
    build_model = model_fn_lm(
        model_config,
        run_settings,
        BertModelWithLabel,
        get_masked_lm_output_fn=get_masked_lm_output,
        feed_feature=True,
    )
    return run_estimator(build_model, build_input)
def main(_):
    """Entry point for the masked-LM training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, expands the comma-separated glob
    patterns in ``FLAGS.input_file`` into a flat file list, builds the
    masked input function and the LM model function for ``BertModel``, and
    runs the estimator.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        None; the result of ``run_estimator`` is discarded.
    """
    tf_logging.info("Train MLM")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    # Each comma-separated entry is a glob pattern; matches are concatenated
    # in pattern order.
    matched_files = [
        path
        for pattern in FLAGS.input_file.split(",")
        for path in tf.io.gfile.glob(pattern)
    ]

    build_input = input_fn_builder_masked2(matched_files, FLAGS, training_mode)
    build_model = model_fn_lm(model_config, run_settings, BertModel)
    run_estimator(build_model, build_input)
def main(_):
    """Entry point for the "MLM with BERT like" training script.

    Loads the model config from ``FLAGS.model_config_file``, derives the
    training config from the flags, builds the ``alt_emb2`` unmasked input
    function and the LM model function for ``BertModel`` with the standard
    masked-LM output head, and runs the estimator.

    Args:
        _: Unused positional argument supplied by the app runner.

    Returns:
        None; the result of ``run_estimator`` is discarded.
    """
    tf_logging.info("Train MLM with BERT like")

    model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
    run_settings = TrainConfigEx.from_flags(FLAGS)
    training_mode = FLAGS.do_train

    build_input = input_fn_builder_unmasked_alt_emb2(
        get_input_files_from_flags(FLAGS),
        FLAGS,
        training_mode,
    )
    build_model = model_fn_lm(
        model_config,
        run_settings,
        BertModel,
        get_masked_lm_output,
    )
    run_estimator(build_model, build_input)
def lm_pretrain(input_files):
    """Build a TPUEstimator for the LM task selected by FLAGS and run it.

    The task is chosen from the flags in this priority order:
    ``FLAGS.target_lm`` (TLM), ``FLAGS.dict_lm`` (dictionary LM),
    ``FLAGS.dict_lm_vbatch`` (dictionary LM with virtual batch size), and
    plain LM when none of them is set. Each task selects its own
    ``model_fn`` and input-function builder. Afterwards training,
    evaluation and prediction are run according to ``FLAGS.do_train``,
    ``FLAGS.do_eval`` and ``FLAGS.do_predict``.

    Args:
        input_files: Sequence of input file paths; it is logged, measured
            with ``len`` and passed to the selected input-function builder.

    Returns:
        The evaluation result mapping when ``FLAGS.do_eval`` is set (in that
        case the function returns right after evaluation, so the prediction
        step is not reached). Otherwise ``None``.

    Raises:
        ValueError: If the TLM task is selected and ``FLAGS.modeling`` is
            not one of the supported values.
    """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf_logging.info("*** Input Files ***")
    for idx, input_file in enumerate(input_files):
        tf_logging.info(" %s" % input_file)
        # Only list the first dozen files to keep the log short.
        if idx > 10:
            break

    # A fixed random seed is used only for prediction runs.
    seed = 0 if FLAGS.do_predict else None

    tf_logging.info("Total of %d files" % len(input_files))

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tf.compat.v1.ConfigProto(allow_soft_placement=False)
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config,
        tf_random_seed=seed,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    TASK_LM = 0
    TASK_TLM = 1
    TASK_DICT_LM = 2
    TASK_DICT_LM_VBATCH = 3

    # Flag priority: target_lm > dict_lm > dict_lm_vbatch > plain LM.
    task = TASK_LM
    if FLAGS.target_lm:
        task = TASK_TLM
    elif FLAGS.dict_lm:
        task = TASK_DICT_LM
    elif FLAGS.dict_lm_vbatch:
        task = TASK_DICT_LM_VBATCH

    train_config = TrainConfigEx.from_flags(FLAGS)

    if task == TASK_LM:
        tf_logging.info("Running LM")
        if FLAGS.fixed_mask:
            input_fn_builder = input_fn_builder_masked
        else:
            input_fn_builder = input_fn_builder_unmasked
        model_fn = model_fn_lm(
            model_config=bert_config,
            train_config=train_config,
            model_class=BertModel,
        )
    elif task == TASK_TLM:
        tf_logging.info("Running TLM")
        model_config = JsonConfig.from_json_file(FLAGS.model_config_file)
        target_model_config = bert_config
        if model_config.compare_attrib_value_safe("not_twin", True):
            target_model_config = model_config

        # Default builder; the BLC variants below override it.
        input_fn_builder = input_fn_builder_unmasked
        if FLAGS.modeling == "nli_ex":
            priority_model = get_nli_ex_model_segmented
        elif FLAGS.modeling == "tlm2":
            priority_model = partial(tlm2, target_model_config, FLAGS.use_tpu)
        elif FLAGS.modeling == "tlm_hard":
            priority_model = partial(tlm_prefer_hard, target_model_config,
                                     FLAGS.use_tpu)
        elif FLAGS.modeling == "BLC":
            priority_model = brutal_loss_compare
            input_fn_builder = input_fn_builder_blc
        elif FLAGS.modeling == "BLC_beta":
            priority_model = blc_beta
            input_fn_builder = input_fn_builder_blc
        else:
            raise ValueError(
                "Unsupported FLAGS.modeling for TLM: %r" % (FLAGS.modeling,))

        # NOTE(review): model_config is passed as target_model_config here
        # regardless of the "not_twin" check above; only the partial-bound
        # priority models receive the computed target_model_config.
        # Confirm this is intended.
        model_fn = model_fn_target_masking(
            bert_config=bert_config,
            train_config=train_config,
            target_model_config=model_config,
            model_class=BertModel,
            priority_model=priority_model,
        )
    elif task == TASK_DICT_LM:
        tf_logging.info("Running Dict LM")
        dbert_config = modeling.BertConfig.from_json_file(
            FLAGS.model_config_file)
        input_fn_builder = input_fn_builder_dict
        model_fn = model_fn_dict_reader(
            bert_config=bert_config,
            dbert_config=dbert_config,
            train_config=train_config,
            logging=tf_logging,
            model_class=DictReaderModel,
            dict_run_config=DictRunConfig.from_flags(FLAGS),
        )
    elif task == TASK_DICT_LM_VBATCH:
        tf_logging.info("Running Dict LM with virtual batch_size")
        ssdr_config = JsonConfig.from_json_file(FLAGS.model_config_file)
        if FLAGS.modeling == "mockup":
            input_fn_builder = input_fn_builder_unmasked
            model_fn = ssdr_model_fn.model_fn_apr_lm(
                bert_config=bert_config,
                ssdr_config=ssdr_config,
                train_config=train_config,
                dict_run_config=DictRunConfig.from_flags(FLAGS),
            )
        elif FLAGS.modeling == "debug":
            tf_logging.info("Running Debugging")
            input_fn_builder = ssdr_model_fn.input_fn_builder
            model_fn = ssdr_model_fn.model_fn_apr_debug(
                bert_config=bert_config,
                ssdr_config=ssdr_config,
                train_config=train_config,
                logging=tf_logging,
                model_name="APR",
                dict_run_config=DictRunConfig.from_flags(FLAGS),
            )
        elif FLAGS.modeling == "debug2":
            tf_logging.info("Running Debugging2")
            input_fn_builder = ssdr_model_fn.input_fn_builder
            model_fn = ssdr_model_fn.model_fn_apr_debug(
                bert_config=bert_config,
                ssdr_config=ssdr_config,
                train_config=train_config,
                logging=tf_logging,
                model_name="BERT",
                dict_run_config=DictRunConfig.from_flags(FLAGS),
            )
        else:
            input_fn_builder = ssdr_model_fn.input_fn_builder
            model_fn = ssdr_model_fn.model_fn_dict_reader(
                bert_config=bert_config,
                ssdr_config=ssdr_config,
                train_config=train_config,
                logging=tf_logging,
                model_class=SSDR,
                dict_run_config=DictRunConfig.from_flags(FLAGS),
            )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
    )

    if FLAGS.do_train:
        tf_logging.info("***** Running training *****")
        tf_logging.info(" Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(input_files=input_files,
                                          flags=FLAGS,
                                          is_training=True)
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf_logging.info("***** Running evaluation *****")
        tf_logging.info(" Batch size = %d", FLAGS.eval_batch_size)
        eval_input_fn = input_fn_builder(input_files=input_files,
                                         flags=FLAGS,
                                         is_training=False)
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf_logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf_logging.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        # Returning here means the prediction step below is skipped whenever
        # evaluation is requested.
        return result

    if FLAGS.do_predict:
        tf_logging.info("***** Running prediction *****")
        tf_logging.info(" Batch size = %d", FLAGS.eval_batch_size)
        predict_input_fn = input_fn_builder(
            input_files=input_files,
            flags=FLAGS,
            is_training=False,
        )
        result = estimator.predict(input_fn=predict_input_fn,
                                   yield_single_examples=False)
        tf_logging.info("***** Pickling.. *****")
        # Materialize the predictions before touching the output file so a
        # failed prediction run does not leave a truncated file behind, then
        # write through a context manager so the handle is always closed.
        predictions = list(result)
        with open(FLAGS.out_file, "wb") as out_f:
            pickle.dump(predictions, out_f)