Пример #1
0
def decode_Beam(FLAGS):
    """Run beam-search decoding over the test data file.

    Loads the input/output vocabularies, builds a ``Batcher`` over
    ``FLAGS.test_name`` using the retyped flags, builds a
    ``SummarizationModel`` from a copy of those flags with ``max_dec_steps``
    set to 1 and ``mode`` set to ``"decode"``, then calls
    ``BeamSearchDecoder.decode()``.

    Args:
        FLAGS: Flag container. Passed to ``data.load_dict_data`` and read
            for ``data_path`` and ``test_name``. The batcher and model
            settings come from ``config.retype_FLAGS()``, not from this
            argument directly.
    """
    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    batcher_flags = config.retype_FLAGS()

    # The model is configured with max_dec_steps=1 because the decoder is
    # only ever run one step at a time (to do beam search). The batcher keeps
    # the unmodified flags (max_dec_steps of e.g. 100) because the batches
    # need to contain the full summaries.
    decode_settings = batcher_flags._asdict()
    for key, value in (("max_dec_steps", 1), ("mode", "decode")):
        decode_settings[key] = value
    model_flags = config.generate_nametuple(decode_settings)

    batcher = Batcher(
        FLAGS.data_path,
        vocab_in,
        vocab_out,
        batcher_flags,
        data_file=FLAGS.test_name,
    )
    model = SummarizationModel(model_flags, vocab_in, vocab_out, batcher)
    BeamSearchDecoder(model, batcher, vocab_out).decode()
Пример #2
0
def train_with_eval(FLAGS):
    """Train the BERT classifier, periodically evaluating and keeping the best checkpoint.

    The model is saved once before training and evaluated with ``eval_acc``
    to obtain the starting best accuracy. Training then runs from the
    model's stored global step up to ``num_train_steps``. Whenever the
    batcher returns ``None`` (logged here as the end of an epoch), and every
    ``FLAGS.save_checkpoints_steps`` steps, ``greedy_model_save`` is called
    and its returned ``(bestDevModel, Best_acc, acc)`` replaces the tracked
    best checkpoint path and accuracy. After the loop the best checkpoint is
    reloaded and saved again under its own path.

    Args:
        FLAGS: Ignored as passed in; it is immediately replaced by
            ``config.retype_FLAGS()``.
    """
    FLAGS = config.retype_FLAGS()
    Bert_model, validate_model = create_train_eval_model(FLAGS)

    checkpoint_basename = os.path.join(FLAGS.output_dir, "Bert-Classify")
    logging.info(checkpoint_basename)
    # Save first so that a checkpoint state exists in output_dir before it
    # is queried below.
    Bert_model.save_model(checkpoint_basename)
    Best_acc = eval_acc(FLAGS, validate_model, Bert_model)
    bestDevModel = tf.train.get_checkpoint_state(FLAGS.output_dir).model_checkpoint_path

    start_step = Bert_model.load_specific_variable(Bert_model.global_step)
    for step in range(start_step, Bert_model.num_train_steps):
        batch = Bert_model.batcher.next_batch()
        # Identity test: `== None` would dispatch to the batch type's __eq__,
        # which is not a reliable "no batch" check for array-like objects.
        if batch is None:
            bestDevModel, Best_acc, acc = greedy_model_save(
                bestDevModel, checkpoint_basename, Best_acc, Bert_model,
                validate_model, FLAGS)
            logging.info("Finish epoch: {}".format(Bert_model.batcher.c_epoch))
            logging.info("ACC {} Best_ACC: {}\n\n".format(acc, Best_acc))
            if Bert_model.batcher.c_epoch >= FLAGS.num_train_epochs:
                break
            # No batch to train on for this step value; move to the next one.
            continue

        results = Bert_model.run_train_step(batch)

        if step % 100 == 0:
            logging.info("step {} loss: {}\n".format(step, results["loss"]))

        if step % FLAGS.save_checkpoints_steps == 0 and step != 0:
            bestDevModel, Best_acc, acc = greedy_model_save(
                bestDevModel, checkpoint_basename, Best_acc, Bert_model,
                validate_model, FLAGS)
            logging.info("ACC {} Best_ACC: {}\n\n".format(acc, Best_acc))

    # Finish on the best checkpoint seen rather than the last trained state.
    Bert_model.load_specific_model(bestDevModel)
    Bert_model.save_model(bestDevModel, False)
Пример #3
0
def create_decode_model(FLAGS, vocab_in, vocab_out):
    """Build an ``eval.EvalDecoder`` over the ``FLAGS.qq_name`` data file.

    The batcher is created from ``FLAGS`` as given. The model is created
    from ``config.retype_FLAGS()`` with ``max_dec_steps`` set to 1 and
    ``mode`` set to ``"decode"``.

    Args:
        FLAGS: Flag container; read for ``data_path`` and ``qq_name`` and
            handed to the batcher.
        vocab_in: Input vocabulary, passed through to batcher, model and
            (not at all) the decoder.
        vocab_out: Output vocabulary, passed to batcher, model and decoder.

    Returns:
        The constructed ``eval.EvalDecoder``.
    """
    qq_batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                         data_file=FLAGS.qq_name)

    # Project-local module named `eval`; imported here, as in the original,
    # rather than at module level.
    import eval

    decode_settings = config.retype_FLAGS()._asdict()
    for key, value in (("max_dec_steps", 1), ("mode", "decode")):
        decode_settings[key] = value
    decode_flags = config.generate_nametuple(decode_settings)

    decode_model = SummarizationModel(decode_flags, vocab_in, vocab_out, qq_batcher)
    return eval.EvalDecoder(decode_model, qq_batcher, vocab_out)
Пример #4
0
def decode_multi(FLAGS):
    """Run ``pair_wise_decode`` over the test file and print the elapsed seconds.

    Builds an unshuffled ``Batcher`` over ``FLAGS.test_name``, a
    ``SummarizationModel`` from ``config.retype_FLAGS()`` with
    ``max_dec_steps`` set to 1 (``mode`` is left as configured), and an
    ``eval.EvalDecoder`` on top of them.

    Args:
        FLAGS: Flag container; passed to ``data.load_dict_data`` and the
            batcher, and read for ``data_path`` and ``test_name``.
    """
    vocab_in, vocab_out = data.load_dict_data(FLAGS)
    test_batcher = Batcher(FLAGS.data_path, vocab_in, vocab_out, FLAGS,
                           data_file=FLAGS.test_name, shuffle=False)

    # Project-local module named `eval`, imported at function scope.
    import eval

    decode_settings = config.retype_FLAGS()._asdict()
    decode_settings["max_dec_steps"] = 1
    decode_flags = config.generate_nametuple(decode_settings)

    decode_model = SummarizationModel(decode_flags, vocab_in, vocab_out, test_batcher)
    pair_decoder = eval.EvalDecoder(decode_model, test_batcher, vocab_out)

    # Wall-clock timing of the decode call only.
    started_at = time.time()
    pair_decoder.pair_wise_decode()
    print(time.time() - started_at)
Пример #5
0
def main(_):
    """Run dev steps of the joint tagging/classification model on merged batches.

    The configured ``train_batch_size`` is split between a tagging batcher
    (task ``"ner"``) and a classification batcher (task ``"qicm"``). Both
    flag copies are switched to ``mode="dev"``, and the classification copy
    reads its train/dev/test files from the ``*_multi`` flag entries. Each
    iteration pulls one batch from each batcher, merges them with
    ``merge_multi_task`` and feeds the result to ``run_dev_step``.

    Args:
        _: Unused positional argument.
    """
    FLAGS = config.retype_FLAGS()
    model_cls = model_pools["tagging_classify_model"]
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tagging_task, classify_task = "ner", "qicm"

    # Tagging side: half of the configured batch size.
    tagging_settings = FLAGS._asdict()
    tagging_settings["train_batch_size"] = int(tagging_settings["train_batch_size"] / 2)
    tagging_settings["mode"] = "dev"
    FLAGS_tagging = config.generate_nametuple(tagging_settings)

    # Classification side: the other half, plus whatever remainder is needed
    # so that both parts add up to the configured batch size.
    classify_settings = FLAGS._asdict()
    classify_half = int(classify_settings["train_batch_size"] / 2)
    classify_settings["mode"] = "dev"
    remainder = FLAGS.train_batch_size - FLAGS_tagging.train_batch_size - classify_half
    classify_settings["train_batch_size"] = classify_half + remainder

    # The classification task reads from the "*_multi" data files.
    for target_key, source_key in (
        ("train_file", "train_file_multi"),
        ("dev_file", "dev_file_multi"),
        ("test_file", "test_file_multi"),
    ):
        classify_settings[target_key] = classify_settings[source_key]
    FLAGS_classify = config.generate_nametuple(classify_settings)

    processor_tagging = processors[tagging_task]()
    processor_classify = processors[classify_task]()

    tagging_batcher = Batcher(processor_tagging, FLAGS_tagging)
    classify_batcher = Batcher(processor_classify, FLAGS_classify)

    # The model itself receives the original (unsplit) flags.
    Bert_model = model_cls(bert_config, tagging_batcher, classify_batcher, FLAGS)

    for _step in range(Bert_model.num_train_steps):
        tagging_batch = Bert_model.tagging_batcher.next_batch()
        classify_batch = Bert_model.classify_batcher.next_batch()
        merged_batch = Bert_model.classify_batcher.merge_multi_task(
            tagging_batch, classify_batch)
        Bert_model.run_dev_step(merged_batch)
Пример #6
0
def deocode_train_eval(FLAGS):
    """Train epoch by epoch with validation, optionally with the multi-task decoder.

    Each loop iteration trains one epoch (``train_one_epoch_multi_task`` when
    ``FLAGS.qq_loss`` is truthy, ``train_one_epoch`` otherwise), validates,
    and then either records a new best checkpoint or counts a bad
    validation. After ``FLAGS.badvalid`` consecutive bad validations the
    best checkpoint is restored and the learning rate decayed.

    The loop stops when the global step read at the top of an iteration
    exceeds ``FLAGS.max_run_steps``, or when, at a restore, the learning
    rate that was read *before* the decay is below 0.001. Finally the best
    checkpoint is reloaded and saved under its own path.

    Args:
        FLAGS: Ignored as passed in; it is immediately replaced by
            ``config.retype_FLAGS()``.
    """
    FLAGS = config.retype_FLAGS()
    vocab_in, vocab_out = data.load_dict_data(FLAGS)
    checkpoint_basename = os.path.join(FLAGS.log_root, "PointerGenerator_model")
    logging.info(checkpoint_basename)

    train_model, dev_model = create_training_model(FLAGS, vocab_in, vocab_out)
    # Save first so a checkpoint state exists in log_root before it is read.
    train_model.save_model(checkpoint_basename)

    decoder = create_decode_model(FLAGS, vocab_in, vocab_out)

    best_bleu, best_acc, _ = validation_acc(dev_model, FLAGS)
    logging.info("bleu_now {}".format(best_bleu))

    tmp_checkpoint = checkpoint_basename + "tmp"
    bad_rounds = 0
    best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path

    while True:
        step = train_model.get_specific_variable(train_model.global_step)
        if step > FLAGS.max_run_steps:
            break

        if FLAGS.qq_loss:
            decoder._model.create_or_load_recent_model()
            loss = train_one_epoch_multi_task(train_model, dev_model, decoder, FLAGS)
        else:
            loss = train_one_epoch(train_model, dev_model, FLAGS)

        if not (np.isnan(loss) or loss < 0):
            train_model.save_model(tmp_checkpoint, False)
            bleu, acc, _ = validation_acc(dev_model, FLAGS)
        else:
            # NaN or negative loss: roll back and score this round as -1.
            logging.info("loss is nan, restore")
            train_model.load_specific_model(best_checkpoint)
            bleu = -1
            acc = -1

        if acc >= best_acc:
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("save new best model, learning rate {} step {}".format(lr, step))
            train_model.save_model(checkpoint_basename)
            bad_rounds = 0

            best_bleu = bleu
            best_acc = acc
            best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
            continue

        # badvalid == 0 turns off the bad-validation bookkeeping entirely.
        if FLAGS.badvalid == 0:
            continue

        bad_rounds += 1
        logging.info("bad valid {} compared with bestDevModel {} bleu {} acc {}".format(
            bad_rounds, best_checkpoint, best_bleu, best_acc))
        lr = train_model.get_specific_variable(train_model.learning_rate)
        logging.info("current learning rate {}".format(lr))
        if bad_rounds < FLAGS.badvalid:
            continue

        logging.info("restore model {} for {}".format(step, best_checkpoint))
        train_model.load_specific_model(best_checkpoint)
        train_model.run_decay_lr()
        train_model.save_model(checkpoint_basename)
        best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
        bad_rounds = 0
        # `lr` still holds the value read before run_decay_lr().
        if lr < 0.001:
            logging.info("lr = {}, stop".format(lr))
            break

    train_model.load_specific_model(best_checkpoint)
    train_model.save_model(best_checkpoint, False)
Пример #7
0
def train_with_eval(FLAGS):
    """Train epoch by epoch, validating after each and keeping the best checkpoint.

    Logs the main hyperparameters, builds the train/dev models and then
    loops: train one epoch, validate, and either record a new best
    checkpoint or count a bad validation. After ``FLAGS.badvalid``
    consecutive bad validations the best checkpoint is restored and the
    learning rate decayed.

    The loop stops when the global step read at the top of an iteration
    exceeds ``FLAGS.max_run_steps``, or when, at a restore, the learning
    rate that was read *before* the decay is below 0.001. Finally the best
    checkpoint is reloaded and saved under its own path.

    Args:
        FLAGS: Ignored as passed in; it is immediately replaced by
            ``config.retype_FLAGS()``.
    """
    FLAGS = config.retype_FLAGS()

    # (log label, FLAGS attribute) pairs, logged in this order.
    for label, attr in (
        ("hidden_dim:", "hidden_dim"),
        ("emb_dim:", "emb_dim"),
        ("batch_size:", "batch_size"),
        ("max_enc_steps:", "max_enc_steps"),
        ("max_dec_steps:", "max_dec_steps"),
        ("learning rate:", "lr"),
    ):
        logging.info(label + str(getattr(FLAGS, attr)))

    # Load the dictionaries.
    vocab_in, vocab_out = data.load_dict_data(FLAGS)

    checkpoint_basename = os.path.join(FLAGS.log_root, "PointerGenerator_model")
    logging.info(checkpoint_basename)

    logging.info("creating model...")
    train_model, dev_model = create_training_model(FLAGS, vocab_in, vocab_out)
    # Save first so a checkpoint state exists in log_root before it is read.
    train_model.save_model(checkpoint_basename)

    best_bleu, best_acc, _ = validation_acc(dev_model, FLAGS)
    logging.info("bleu_now {}".format(best_bleu))

    tmp_checkpoint = checkpoint_basename + "tmp"
    bad_rounds = 0
    best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path

    while True:
        step = train_model.get_specific_variable(train_model.global_step)
        if step > FLAGS.max_run_steps:
            break

        loss = train_one_epoch(train_model, dev_model, FLAGS)
        if not (np.isnan(loss) or loss < 0):
            train_model.save_model(tmp_checkpoint, False)
            bleu, acc, _ = validation_acc(dev_model, FLAGS)
        else:
            # NaN or negative loss: roll back and score this round as -1.
            logging.info("loss is nan, restore")
            train_model.load_specific_model(best_checkpoint)
            bleu = -1
            acc = -1

        if acc >= best_acc:
            lr = train_model.get_specific_variable(train_model.learning_rate)
            logging.info("save new best model, learning rate {} step {}".format(lr, step))
            train_model.save_model(checkpoint_basename)
            bad_rounds = 0

            best_bleu = bleu
            best_acc = acc
            best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
            continue

        # badvalid == 0 turns off the bad-validation bookkeeping entirely.
        if FLAGS.badvalid == 0:
            continue

        bad_rounds += 1
        logging.info("bad valid {} compared with bestDevModel {} bleu {} acc {}".format(
            bad_rounds, best_checkpoint, best_bleu, best_acc))
        lr = train_model.get_specific_variable(train_model.learning_rate)
        logging.info("current learning rate {}".format(lr))
        if bad_rounds < FLAGS.badvalid:
            continue

        logging.info("restore model {} for {}".format(step, best_checkpoint))
        train_model.load_specific_model(best_checkpoint)
        train_model.run_decay_lr()
        train_model.save_model(checkpoint_basename)
        best_checkpoint = tf.train.get_checkpoint_state(FLAGS.log_root).model_checkpoint_path
        bad_rounds = 0
        # `lr` still holds the value read before run_decay_lr().
        if lr < 0.001:
            logging.info("lr = {}, stop".format(lr))
            break

    train_model.load_specific_model(best_checkpoint)
    train_model.save_model(best_checkpoint, False)