def decode_Beam(FLAGS):
    # If in decode mode, set batch_size = beam_size.
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    # if FLAGS.mode == 'decode':
    #     FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode.
    # if FLAGS.single_pass and FLAGS.mode != 'decode':
    #     raise Exception("The single_pass flag should only be True in decode mode")

    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    FLAGS_batcher = config.retype_FLAGS()
    FLAGS_decode = FLAGS_batcher._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode["mode"] = "decode"
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder
    # at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to
    # e.g. 100 because the batches need to contain the full summaries.
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS_batcher, data_file=FLAGS.test_name)
    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    decoder = BeamSearchDecoder(model, batcher, vocab_out)
    decoder.decode()
def train_with_eval(FLAGS):
    FLAGS = config.retype_FLAGS()

    Bert_model, validate_model = create_train_eval_model(FLAGS)

    checkpoint_basename = os.path.join(FLAGS.output_dir, "Bert-Classify")
    logging.info(checkpoint_basename)
    Bert_model.save_model(checkpoint_basename)

    Best_acc = eval_acc(FLAGS, validate_model, Bert_model)
    bestDevModel = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path

    start_step = Bert_model.load_specific_variable(Bert_model.global_step)
    for step in range(start_step, Bert_model.num_train_steps):
        batch = Bert_model.batcher.next_batch()
        if batch is None:
            # End of an epoch: evaluate and keep the best checkpoint so far.
            bestDevModel, Best_acc, acc = greedy_model_save(bestDevModel, checkpoint_basename, Best_acc,
                                                            Bert_model, validate_model, FLAGS)
            logging.info("Finish epoch: {}".format(Bert_model.batcher.c_epoch))
            logging.info("ACC {} Best_ACC: {}\n\n".format(acc, Best_acc))
            if Bert_model.batcher.c_epoch >= FLAGS.num_train_epochs:
                break
            continue

        results = Bert_model.run_train_step(batch)
        if step % 100 == 0:
            logging.info("step {} loss: {}\n".format(step, results["loss"]))

        if step % FLAGS.save_checkpoints_steps == 0 and step != 0:
            bestDevModel, Best_acc, acc = greedy_model_save(bestDevModel, checkpoint_basename, Best_acc,
                                                            Bert_model, validate_model, FLAGS)
            logging.info("ACC {} Best_ACC: {}\n\n".format(acc, Best_acc))

    # Restore and export the best checkpoint found during training.
    Bert_model.load_specific_model(bestDevModel)
    Bert_model.save_model(bestDevModel, False)
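# greedy_model_save is defined elsewhere in this repo and is not shown here. Based only on how it is
# called above, a minimal sketch of the assumed behavior might look like the following (the name
# *_sketch and all details are assumptions, not the repo's actual implementation): evaluate dev-set
# accuracy, and only keep the checkpoint if accuracy improved.
def greedy_model_save_sketch(bestDevModel, checkpoint_basename, Best_acc, Bert_model, validate_model, FLAGS):
    acc = eval_acc(FLAGS, validate_model, Bert_model)  # dev-set accuracy of the current weights
    if acc >= Best_acc:
        # New best model: overwrite the checkpoint and remember its path.
        Bert_model.save_model(checkpoint_basename)
        Best_acc = acc
        bestDevModel = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path
    return bestDevModel, Best_acc, acc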
def create_decode_model(FLAGS, vocab_in, vocab_out):
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS, data_file=FLAGS.qq_name)

    import eval

    FLAGS_decode = config.retype_FLAGS()._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode["mode"] = "decode"
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    # model.graph.as_default()
    decoder = eval.EvalDecoder(model, batcher, vocab_out)
    return decoder
def decode_multi(FLAGS):
    vocab_in, vocab_out = data.load_dict_data(FLAGS)
    batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS, data_file=FLAGS.test_name, shuffle=False)

    import eval

    FLAGS_decode = config.retype_FLAGS()._asdict()
    FLAGS_decode["max_dec_steps"] = 1
    FLAGS_decode = config.generate_nametuple(FLAGS_decode)

    model = SummarizationModel(FLAGS_decode, vocab_in, vocab_out, batcher)
    decoder = eval.EvalDecoder(model, batcher, vocab_out)

    # Time the pair-wise decoding pass.
    time_start = time.time()
    decoder.pair_wise_decode()
    time_end = time.time()
    print(time_end - time_start)
def main(_):
    FLAGS = config.retype_FLAGS()
    Classify_model = model_pools["tagging_classify_model"]
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    task_name1 = "ner"
    task_name2 = "qicm"

    FLAGS_tagging = FLAGS._asdict()
    FLAGS_tagging["train_batch_size"] = int(FLAGS_tagging["train_batch_size"] / 2)
    FLAGS_tagging["mode"] = "dev"
    FLAGS_tagging = config.generate_nametuple(FLAGS_tagging)

    FLAGS_classify = FLAGS._asdict()
    FLAGS_classify["train_batch_size"] = int(FLAGS_classify["train_batch_size"] / 2)
    FLAGS_classify["mode"] = "dev"
    # Pad so the two half-batches add up to the full train_batch_size.
    FLAGS_classify["train_batch_size"] += FLAGS.train_batch_size - FLAGS_tagging.train_batch_size - FLAGS_classify["train_batch_size"]
    FLAGS_classify["train_file"] = FLAGS_classify["train_file_multi"]
    FLAGS_classify["dev_file"] = FLAGS_classify["dev_file_multi"]
    FLAGS_classify["test_file"] = FLAGS_classify["test_file_multi"]
    FLAGS_classify = config.generate_nametuple(FLAGS_classify)

    processor_tagging = processors[task_name1]()
    processor_classify = processors[task_name2]()

    tagging_batcher = Batcher(processor_tagging, FLAGS_tagging)
    classify_batcher = Batcher(processor_classify, FLAGS_classify)

    # Create the training model.
    Bert_model = Classify_model(bert_config, tagging_batcher, classify_batcher, FLAGS)

    for step in range(0, Bert_model.num_train_steps):
        tagging_batch = Bert_model.tagging_batcher.next_batch()
        classify_batch = Bert_model.classify_batcher.next_batch()
        batch = Bert_model.classify_batcher.merge_multi_task(tagging_batch, classify_batch)
        results = Bert_model.run_dev_step(batch)
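# merge_multi_task is provided by the repo's Batcher and is not shown here. A minimal sketch of the
# assumed idea, purely illustrative (the real batch objects and field names may differ): concatenate
# the tagging half-batch and the classify half-batch along the batch dimension so the joint batch
# again contains FLAGS.train_batch_size examples.
import numpy as np

def merge_multi_task_sketch(tagging_batch, classify_batch):
    # Assumes each batch is a dict of per-field arrays with the batch dimension first.
    merged = {}
    for key in tagging_batch:
        merged[key] = np.concatenate([tagging_batch[key], classify_batch[key]], axis=0)
    return merged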
def deocode_train_eval(FLAGS):
    FLAGS = config.retype_FLAGS()
    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    checkpoint_basename = os.path.join(FLAGS.log_root, "PointerGenerator_model")
    logging.info(checkpoint_basename)

    # decoder.decode()
    # batcher_pair = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS, data_file=FLAGS.qq_name)

    train_model, dev_model = create_training_model(FLAGS, vocab_in, vocab_out)
    train_model.save_model(checkpoint_basename)
    decoder = create_decode_model(FLAGS, vocab_in, vocab_out)

    best_bleu, best_acc, dev_loss = validation_acc(dev_model, FLAGS)
    logging.info("bleu_now {}".format(best_bleu))

    tmpDevModel = checkpoint_basename + "tmp"
    bad_valid = 0
    bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path

    while True:
        step = train_model.get_specific_variable(train_model.global_step)
        if step > FLAGS.max_run_steps:
            break

        if FLAGS.qq_loss:
            decoder._model.create_or_load_recent_model()
            loss = train_one_epoch_multi_task(train_model, dev_model, decoder, FLAGS)
        else:
            loss = train_one_epoch(train_model, dev_model, FLAGS)

        if np.isnan(loss) or loss < 0:
            # Training diverged: fall back to the best checkpoint so far.
            logging.info("loss is nan, restore")
            train_model.load_specific_model(bestDevModel)
            bleu = -1
            acc = -1
        else:
            train_model.save_model(tmpDevModel, False)
            bleu, acc, dev_loss = validation_acc(dev_model, FLAGS)

        if acc >= best_acc:
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("save new best model, learning rate {} step {}".format(lr, step))
            train_model.save_model(checkpoint_basename)
            bad_valid = 0
            best_bleu = bleu
            best_acc = acc
            bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
        else:
            if FLAGS.badvalid == 0:
                continue
            bad_valid += 1
            logging.info("bad valid {} compared with bestDevModel {} bleu {} acc {}".format(bad_valid, bestDevModel, best_bleu, best_acc))
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("current learning rate {}".format(lr))
            if bad_valid >= FLAGS.badvalid:
                # Too many validations without improvement: restore the best model,
                # decay the learning rate, and continue training from there.
                logging.info("restore model {} for {}".format(step, bestDevModel))
                train_model.load_specific_model(bestDevModel)
                train_model.run_decay_lr()
                train_model.save_model(checkpoint_basename)
                bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
                bad_valid = 0
            if lr < 0.001:
                logging.info("lr = {}, stop".format(lr))
                break

    # decoder._model.create_or_load_recent_model()
    # decoder.decode()
    train_model.load_specific_model(bestDevModel)
    train_model.save_model(bestDevModel, False)
def train_with_eval(FLAGS):
    FLAGS = config.retype_FLAGS()

    logging.info("hidden_dim:" + str(FLAGS.hidden_dim))
    logging.info("emb_dim:" + str(FLAGS.emb_dim))
    logging.info("batch_size:" + str(FLAGS.batch_size))
    logging.info("max_enc_steps:" + str(FLAGS.max_enc_steps))
    logging.info("max_dec_steps:" + str(FLAGS.max_dec_steps))
    logging.info("learning rate:" + str(FLAGS.lr))

    # Load dictionary.
    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    checkpoint_basename = os.path.join(FLAGS.log_root, "PointerGenerator_model")
    logging.info(checkpoint_basename)

    logging.info("creating model...")
    train_model, dev_model = create_training_model(FLAGS, vocab_in, vocab_out)
    train_model.save_model(checkpoint_basename)

    best_bleu, best_acc, dev_loss = validation_acc(dev_model, FLAGS)
    logging.info("bleu_now {}".format(best_bleu))

    tmpDevModel = checkpoint_basename + "tmp"
    # train_model.save_model(tmpDevModel, False)
    bad_valid = 0
    bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path

    while True:
        step = train_model.get_specific_variable(train_model.global_step)
        if step > FLAGS.max_run_steps:
            break

        loss = train_one_epoch(train_model, dev_model, FLAGS)

        if np.isnan(loss) or loss < 0:
            # Training diverged: fall back to the best checkpoint so far.
            logging.info("loss is nan, restore")
            train_model.load_specific_model(bestDevModel)
            bleu = -1
            acc = -1
        else:
            train_model.save_model(tmpDevModel, False)
            bleu, acc, dev_loss = validation_acc(dev_model, FLAGS)

        if acc >= best_acc:
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("save new best model, learning rate {} step {}".format(lr, step))
            train_model.save_model(checkpoint_basename)
            bad_valid = 0
            best_bleu = bleu
            best_acc = acc
            bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
        else:
            if FLAGS.badvalid == 0:
                continue
            bad_valid += 1
            logging.info("bad valid {} compared with bestDevModel {} bleu {} acc {}".format(bad_valid, bestDevModel, best_bleu, best_acc))
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("current learning rate {}".format(lr))
            if bad_valid >= FLAGS.badvalid:
                # Too many validations without improvement: restore the best model,
                # decay the learning rate, and continue training from there.
                logging.info("restore model {} for {}".format(step, bestDevModel))
                train_model.load_specific_model(bestDevModel)
                train_model.run_decay_lr()
                train_model.save_model(checkpoint_basename)
                bestDevModel = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
                bad_valid = 0
            if lr < 0.001:
                logging.info("lr = {}, stop".format(lr))
                break

    # Restore and export the best checkpoint found during training.
    train_model.load_specific_model(bestDevModel)
    train_model.save_model(bestDevModel, False)
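# train_one_epoch and validation_acc are defined elsewhere in this repo. A minimal sketch of the
# assumed train_one_epoch behavior, based only on how it is used above (the *_sketch name, the
# batcher attribute, and run_train_step on train_model are assumptions, not the repo's actual
# implementation): run training steps until the batcher signals the end of the epoch, and return
# an average loss so the caller can detect divergence (NaN or negative).
def train_one_epoch_sketch(train_model, dev_model, FLAGS):
    losses = []
    while True:
        batch = train_model.batcher.next_batch()
        if batch is None:  # end of epoch
            break
        results = train_model.run_train_step(batch)
        losses.append(results["loss"])
    return sum(losses) / max(len(losses), 1)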