Example #1
def get_dataset(device):
    # with open(params.api_dir, "rb") as fh:
    #     api = pkl.load(fh, encoding='latin1')

    api = SWDADialogCorpus(params.data_dir)

    dial_corpus = api.get_dialog_corpus()

    train_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("test")

    # convert to numeric inputs/outputs
    train_loader = SWDADataLoader("Train",
                                  train_dial,
                                  params.max_utt_len,
                                  params.max_dialog_len,
                                  device=device)
    valid_loader = test_loader = SWDADataLoader("Test",
                                                test_dial,
                                                params.max_utt_len,
                                                params.max_dialog_len,
                                                device=device)
    if api.word2vec is not None:
        return train_loader, valid_loader, test_loader, np.array(api.word2vec)
    else:
        return train_loader, valid_loader, test_loader, None
def main():
    # config for training
    config = Config()
    config.use_bow = False

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60
    valid_config.use_bow = False

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1
    test_config.use_bow = False

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get(
        "train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get(
        "train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit the TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # begin training
    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    # sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    with tf.Session(config=sess_config) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w,
                                                    config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess,
                              config,
                              api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False,
                              scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess,
                                    valid_config,
                                    api,
                                    log_dir=None,
                                    forward=False,
                                    scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess,
                                   test_config,
                                   api,
                                   log_dir=None,
                                   forward=True,
                                   scope=scope)

        test_model.prepare_mul_ref()

        logger.info("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            logger.info("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        logger.info("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            logger.info("Reading dm models parameters from %s" %
                        ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                logger.info(">> Epoch %d with lr %f" %
                            (epoch,
                             sess.run(model.learning_rate_cyc,
                                      {model.global_t: global_t})))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size,
                                          config.backward_size,
                                          config.step_size,
                                          shuffle=True)
                global_t, train_loss = model.train(
                    global_t,
                    sess,
                    train_feed,
                    update_limit=config.update_limit)

                # begin validation
                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Val")
                valid_feed.epoch_init(valid_config.batch_size,
                                      valid_config.backward_size,
                                      valid_config.step_size,
                                      shuffle=False,
                                      intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Test")
                test_feed.epoch_init(valid_config.batch_size,
                                     valid_config.backward_size,
                                     valid_config.step_size,
                                     shuffle=False,
                                     intra_shuffle=False)
                valid_model.valid("ELBO_TEST", sess, test_feed)

                # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                #                      test_config.step_size, shuffle=True, intra_shuffle=False)
                # test_model.test_mul_ref(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience,
                                       done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        logger.info("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path)
                    best_dev_loss = valid_loss

                    if (epoch % 3) == 2:
                        tmp_model_path = os.path.join(
                            ckp_dir,
                            model.__class__.__name__ + str(epoch) + ".ckpt")
                        model.saver.save(sess, tmp_model_path)

                if config.early_stop and patience <= done_epoch:
                    logger.info("!!Early stop due to run out of patience!!")
                    break
            logger.info("Best validation loss %f" % best_dev_loss)
            logger.info("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            test_model.test_mul_ref(sess,
                                    test_feed,
                                    num_batch=None,
                                    repeat=5,
                                    dest=dest_f)
            dest_f.close()
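As a side note, here is a minimal self-contained sketch of the patience-based early-stopping bookkeeping used in the training loop above; the helper name, the state dict, and the demo thresholds are illustrative assumptions, not part of the original code.

import numpy as np

def update_early_stopping(state, valid_loss, done_epoch,
                          improve_threshold=0.996, patient_increase=2.0):
    # Extend patience when the dev loss improves enough, track the best loss,
    # and report whether patience has run out (True -> stop training).
    if valid_loss < state["best_dev_loss"]:
        if valid_loss <= state["dev_loss_threshold"] * improve_threshold:
            state["patience"] = max(state["patience"], done_epoch * patient_increase)
            state["dev_loss_threshold"] = valid_loss
        state["best_dev_loss"] = valid_loss
    return state["patience"] <= done_epoch

# Patience shortened to 3 (instead of 10) so the toy run actually stops early.
state = {"patience": 3, "dev_loss_threshold": np.inf, "best_dev_loss": np.inf}
for epoch, loss in enumerate([3.2, 3.0, 2.99, 2.991, 2.992], start=1):
    if update_early_stopping(state, loss, epoch):
        print("early stop at epoch", epoch)
        break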
Example #3
def main():

    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # This path must be passed to LoadPoem explicitly on the command line, since its default is None.
    # Prepare both the pretraining data and the full poem corpus.

    # api = LoadPoem(args.train_data_dir, args.test_data_dir, args.max_vocab_size)
    api = LoadPoem(corpus_path=args.train_data_dir,
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)

    # Alternating training: prepare the large dataset.
    poem_corpus = api.get_tokenized_poem_corpus(
        type=1 + int(with_sentiment))  # corpus for training and validation
    test_data = api.get_tokenized_test_corpus()  # test data
    # Three lists; every element of each list is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus["train"], poem_corpus[
        "valid"], test_data["test"]

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # LOG #
        log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
        if not os.path.isdir('./output'):
            os.makedirs('./output')
        if not os.path.isdir('./output/{}'.format(args.expname)):
            os.makedirs('./output/{}'.format(args.expname))
        if not os.path.isdir('./output/{}/{}'.format(args.expname,
                                                     log_start_time)):
            os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

        # save arguments
        json.dump(
            vars(args),
            open(
                './output/{}/{}/args.json'.format(args.expname,
                                                  log_start_time), 'w'))

        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
            args.expname, log_start_time))
        # add the handlers to the logger
        logger.addHandler(fh)
        logger.info(vars(args))

        tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
            args.expname, log_start_time)) if args.visual else None

        if config.reload_model:
            model = load_model(config.model_name)
        else:
            if args.model == "mCVAE":
                model = CVAE_GMP(config=config, api=api)
            elif args.model == 'CVAE':
                model = CVAE(config=config, api=api)
            else:
                model = Seq2Seq(config=config, api=api)
            if use_cuda:
                model = model.cuda()

        # if corpus.word2vec is not None and args.reload_from<0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # The model is still PoemWAE_GMP, unchanged; this data only force-trains one of its Gaussian prior components.
        # pretrain = True

        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader, config=config, sentiment_data=with_sentiment)
                if finish_train:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    # evaluate_process(model=model, valid_loader=valid_loader, log_start_time=log_start_time, global_t=global_t, epoch=epoch_id, logger=logger, tb_writer=tb_writer, api=api)
                    # save model after each epoch
                    save_model(model=model,
                               epoch=epoch_id,
                               global_t=global_t,
                               log_start_time=log_start_time)
                    logger.info(
                        'Finished epoch %d, current min valid loss: %.4f '
                        '(global itr: %d, epoch: %d, itr: %d)\n\n' %
                        (epoch_id,
                         cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    epoch_id += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write a log entry
                if global_t % config.log_every == 0:
                    log = 'Epoch id %d: step: %d/%d: ' \
                          % (epoch_id, global_t % train_loader.num_batch, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        if loss_name == 'avg_lead_loss':
                            continue
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_t)
                    logger.info(log)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    valid_process(
                        global_t=global_t,
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        unlabeled_epoch=epoch_id,  # if sample_rate_unlabeled is not 1, append an extra 1 here
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)

    # forward_only: run testing
    else:
        expname = 'sentInput'
        time = '202101191105'

        model = load_model(
            './output/{}/{}/model_global_t_13596_epoch3.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]

        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses a dedicated batch method
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: each batch produces one poem
            title_tensor = to_tensor(title_list)
            # The test function decodes the poem for the current batch; each decode's input context is the previous decode's output.
            for i in range(3):
                sentiment_label = np.zeros(1, dtype=np.int64)
                sentiment_label[0] = int(i)
                sentiment_label = to_tensor(sentiment_label)
                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sentiment_label=sentiment_label)

                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            _, neg[i], neu[i], pos[i] = test_sentiment(predict_results[i])
            total[i] = neg[i] + neu[i] + pos[i]

        for i in range(3):
            print("%d%%\t%d%%\t%d%%" %
                  (neg * 100 / total, neu * 100 / total, pos * 100 / total))

        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()

        print("Done testing")
Example #4
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get(
        "train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get(
        "train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit the TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    if True:
        scope = "model"
        model = KgRnnCVAE(config,
                          api,
                          log_dir=None if FLAGS.forward_only else log_dir,
                          scope=scope)

        print("Created computation graphs")
        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        model.apply(lambda m: [
            torch.nn.init.uniform_(p.data, -1.0 * config.init_w, config.init_w)
            for p in m.parameters()
        ])

        # Load word2vec weight
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            model.embedding.weight.data.copy_(
                torch.from_numpy(np.array(api.word2vec)))
        model.embedding.weight.data[0].fill_(0)

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt)
            model.load_state_dict(torch.load(ckpt))

        # turn to cuda
        model.cuda()

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + "-%d.pth")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size,
                                          config.backward_size,
                                          config.step_size,
                                          shuffle=True)
                global_t, train_loss = model.train_model(
                    global_t, train_feed, update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size,
                                      valid_config.backward_size,
                                      valid_config.step_size,
                                      shuffle=False,
                                      intra_shuffle=False)
                model.eval()
                valid_loss = model.valid_model("ELBO_VALID", valid_feed)

                test_feed.epoch_init(test_config.batch_size,
                                     test_config.backward_size,
                                     test_config.step_size,
                                     shuffle=True,
                                     intra_shuffle=False)
                model.test_model(test_feed, num_batch=5)
                model.train()

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    model.learning_rate_decay()

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience,
                                       done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        torch.save(model.state_dict(),
                                   dm_checkpoint_path % (epoch))
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            model.eval()
            model.valid_model("ELBO_VALID", valid_feed)

            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            model.valid_model("ELBO_TEST", test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            model.test_model(test_feed, num_batch=None, repeat=10, dest=dest_f)
            model.train()
            dest_f.close()
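For reference, a small sketch of the per-epoch checkpoint naming pattern ("<ClassName>-<epoch>.pth") used above; the stand-in model and directory are placeholders, not the project's real objects:

import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the real KgRnnCVAE model
ckp_dir = "./checkpoints"
os.makedirs(ckp_dir, exist_ok=True)
dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + "-%d.pth")

epoch = 3
torch.save(model.state_dict(), dm_checkpoint_path % epoch)     # writes ./checkpoints/Linear-3.pth
model.load_state_dict(torch.load(dm_checkpoint_path % epoch))  # restore that epoch's weights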
Example #5
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    '''item in meta like:
     ([0.46, 0.6666666666666666, 1, 0], [0.48, 0.6666666666666666, 1, 0], 29)
    '''
    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    '''item in dial
    ([utt_id,], speaker_id, [dialogact_id, [,,,]])
    '''
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit the TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run"+str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api, log_dir=None if FLAGS.forward_only else log_dir, forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None, forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None, forward=True, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__+ ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed, update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        elif not FLAGS.demo:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
        else:
            infer_api = Corpus(FLAGS.vocab, FLAGS.da_vocab, FLAGS.topic_vocab)
            while True:
                demo_sent = raw_input('Please input your sentence:\t')
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                speaker = raw_input('Are you speaker A?: 1 or 0\t')
                topic = raw_input('what is your topic:\t')
                # demo_sent = 'Hello'
                # speaker = '1'
                # topic = 'MUSIC'
                meta, dial = infer_api.format_input(demo_sent, speaker, topic)
                infer_feed = SWDADataLoader("Infer", dial, meta, config)
                infer_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
                pred_strs, pred_das = test_model.inference(sess, infer_feed, num_batch=None, repeat=3)
                if pred_strs and pred_das:
                    print('Here is your answer')
                    for da, answer in zip(pred_strs, pred_das):
                        print('{} >> {}'.format(da, answer))
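The demo branch above relies on Python 2's raw_input. A Python 3 sketch of the same prompt loop, with the corpus formatting and model inference reduced to one injected callable (the real code builds an SWDADataLoader and calls test_model.inference, as shown above):

def demo_loop(infer_fn):
    # Prompt until the user enters a blank sentence, then print (dialog-act, answer) pairs.
    while True:
        demo_sent = input('Please input your sentence:\t')
        if not demo_sent.strip():
            print('See you next time!')
            break
        speaker = input('Are you speaker A?: 1 or 0\t')
        topic = input('what is your topic:\t')
        for da, answer in infer_fn(demo_sent, speaker, topic):
            print('{} >> {}'.format(da, answer))

# demo_loop(lambda sent, spk, topic: [('statement', 'hello there')])  # toy stub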
Example #6
def main():
    ## random seeds
    seed = FLAGS.seed
    # tf.random.set_seed(seed)
    np.random.seed(seed)

    ## config for training
    config = Config()
    pid = PIDControl(FLAGS.exp_KL)
    
    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")
    
    # convert to numeric inputs/outputs that fit the TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        # log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)
    else:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)

    
    ## begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api, log_dir=None if FLAGS.forward_only else log_dir, forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None, forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None, forward=True, pid_control=pid, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "configure.log"), "wb") as f:
                f.write(pp(config, output=False))
        
        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        print("*******checkpoint path: ", ckp_dir)
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        ### save log when running
        if not FLAGS.forward_only:
            logfileName = "train.log"
        else:
            logfileName = "test.log"
        fw_log = open(os.path.join(log_dir, logfileName), "w")
        print("log directory >>> : ", os.path.join(log_dir, "run.log"))
        if not FLAGS.forward_only:
            print('--start training now---')
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__+ ".ckpt")
            global_t = 1
            patience = 20  # wait for at least 20 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            pbar = tqdm(total=config.max_epoch)
            ## epoch start training
            for epoch in range(config.max_epoch):
                pbar.update(1)
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                ## begin training
                FLAGS.mode = 'train'
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed, update_limit=config.update_limit)
                
                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
                elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)
                print('middle test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'.format(nll, ppl, au_count, kl_loss))
                fw_log.write('epoch: {} testing nll: {} ppl: {} ActiveUnit: {} kl_loss: {} elbo: {}\n'.format(
                    epoch, nll, ppl, au_count, kl_loss, elbo))
                fw_log.flush()
                
                '''
                ## begin validation
                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                ## test model
                FLAGS.mode = 'test'
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a models if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
                    ## print("Best validation loss %f" % best_dev_loss)
                 '''
            print("Done training and save checkpoint")

            if FLAGS.save_model:
                print("Save model!!")
                model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
            # begin validation
            print('--------after training to testing now-----')
            FLAGS.mode = 'test'
            # valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                #   valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            
            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)

            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'.format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll: {} ppl: {} ActiveUnit: {} kl_loss: {} elbo: {}\n'.format(
                nll, ppl, au_count, kl_loss, elbo))
            
            dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
            print("****testing done****")
        else:
            # begin validation
            print('*'*89)
            print('--------testing now-----')
            print('*'*89)
            FLAGS.mode = 'test'
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)

            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'.format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll: {} ppl: {} ActiveUnit: {} kl_loss: {} elbo: {}\n'.format(
                nll, ppl, au_count, kl_loss, elbo))
            # dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
            #                      test_config.step_size, shuffle=False, intra_shuffle=False)
            # test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            # dest_f.close()
            print("****testing done****")
        fw_log.close()
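PIDControl(FLAGS.exp_KL) is constructed above but its implementation is not part of this snippet. As background, a minimal PI-style controller that steers a KL weight in [0, 1] toward a target KL value; the coefficients and update rule here are illustrative assumptions, not the project's actual controller:

import math

class SimplePIControl:
    # Toy PI controller: keep the KL weight low while the observed KL is below target, raise it otherwise.
    def __init__(self, target_kl, kp=0.01, ki=0.0001):
        self.target_kl, self.kp, self.ki = target_kl, kp, ki
        self.err_sum = 0.0

    def step(self, observed_kl):
        error = self.target_kl - observed_kl     # positive while KL is still too small
        self.err_sum += error
        p_term = self.kp / (1.0 + math.exp(error))
        weight = p_term - self.ki * self.err_sum
        return min(1.0, max(0.0, weight))

pid = SimplePIControl(target_kl=3.0)
for kl in [0.5, 1.2, 2.0, 2.8, 3.1]:
    print(round(pid.step(kl), 4))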
Example #7
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    pretrain = False

    ###############################################################################
    # Logs
    ###############################################################################
    log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))
    if not os.path.isdir('./output/{}/{}'.format(args.expname,
                                                 log_start_time)):
        os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

    # save arguments
    json.dump(
        vars(args),
        open('./output/{}/{}/args.json'.format(args.expname, log_start_time),
             'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
        args.expname, log_start_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
        args.expname, log_start_time)) if args.visual else None

    ###############################################################################
    # Model
    ###############################################################################
    # vocab and rev_vocab
    with open(args.vocab_path) as vocab_file:
        vocab = vocab_file.read().strip().split('\n')
        rev_vocab = {vocab[idx]: idx for idx in range(len(vocab))}

    if not pretrain:
        pass
        # assert config.reload_model
        # model = load_model(config.model_name)
    else:
        if args.model == "multiVAE":
            model = multiVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        else:
            model = CVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        if use_cuda:
            model = model.cuda()
    ###############################################################################
    # Load data
    ###############################################################################

    if pretrain:
        from collections import defaultdict
        api = LoadPretrainPoem(corpus_path=args.pretrain_data_dir,
                               vocab_path="data/vocab.txt")

        train_corpus, valid_corpus = defaultdict(list), defaultdict(list)
        divide = 50000
        train_corpus['pos'], valid_corpus['pos'] = api.data[
            'pos'][:divide], api.data['pos'][divide:]
        train_corpus['neu'], valid_corpus['neu'] = api.data[
            'neu'][:divide], api.data['neu'][divide:]
        train_corpus['neg'], valid_corpus['neg'] = api.data[
            'neg'][:divide], api.data['neg'][divide:]

        token_corpus = defaultdict(dict)
        token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
            api.get_tokenized_poem_corpus(train_corpus['pos'], valid_corpus['pos']), \
            api.get_tokenized_poem_corpus(train_corpus['neu'], valid_corpus['neu']), \
            api.get_tokenized_poem_corpus(train_corpus['neg'], valid_corpus['neg']),
        # train_loader_dict = {'pos': }

        train_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['train'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['train'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['train'],
                                  config)
        }

        valid_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['valid'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['valid'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['valid'],
                                  config)
        }
        ###############################################################################
        # Pretrain three VAEs
        ###############################################################################

        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch

                model, finish_train, loss_records, global_t = \
                    pre_train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break
                # write a log entry
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id,
                                    global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records,
                                    logger=logger,
                                    tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    pre_valid_process(global_t=global_t,
                                      model=model,
                                      valid_loader=valid_loader,
                                      valid_config=valid_config,
                                      tb_writer=tb_writer,
                                      logger=logger)
                if global_t % config.test_every == 0:
                    pre_test_process(model=model, logger=logger)
    ###############################################################################
    # Train the big model
    ###############################################################################
    api = LoadPoem(corpus_path=args.train_data_dir,
                   vocab_path="data/vocab.txt",
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)
    from collections import defaultdict
    token_corpus = defaultdict(dict)
    token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
        api.get_tokenized_poem_corpus(api.train_corpus['pos'], api.valid_corpus['pos']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neu'], api.valid_corpus['neu']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neg'], api.valid_corpus['neg']),

    train_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['train'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['train'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['train'], config)
    }

    valid_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['valid'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['valid'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['valid'], config)
    }
    test_poem = api.get_tokenized_test_corpus()['test']  # test data
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # The model is still PoemWAE_GMP, unchanged; this data only force-trains one of its Gaussian prior components.
        # pretrain = True

        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # write a log entry
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id,
                                    global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records,
                                    logger=logger,
                                    tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    valid_process(global_t=global_t,
                                  model=model,
                                  valid_loader=valid_loader,
                                  valid_config=valid_config,
                                  tb_writer=tb_writer,
                                  logger=logger)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)

        # forward_only: run testing
    else:
        expname = 'trainVAE'
        time = '202101231631'

        model = load_model(
            './output/{}/{}/model_global_t_26250_epoch9.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]
        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # sent_dict = {0: ['0', '1', '1', '0'], 1: ['2', '1', '2', '2'], 2: ['1', '0', '1', '2']}
        sent_dict = {
            0: ['0', '0', '0', '0'],
            1: ['1', '1', '1', '1'],
            2: ['2', '2', '2', '2']
        }
        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses a dedicated batch method
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: one batch generates one poem
            title_tensor = to_tensor(title_list)
            # test() decodes the poem for the current batch; each decoding step's input context is the previous step's output
            for i in range(3):
                sent_labels = list(sent_dict[i])  # copy, so the appends below do not accumulate in sent_dict across poems
                for _ in range(4):
                    sent_labels.append(str(i))

                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sent_labels=sent_labels)

                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment use the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            cur_sent_result, neg[i], neu[i], pos[i] = test_sentiment(
                predict_results[i])
            sentiment_result[i] = cur_sent_result
            total[i] = neg[i] + neu[i] + pos[i]

        for i in range(3):
            print("%d%%\t%d%%\t%d%%" % (neg[i] * 100 / total[i], neu[i] * 100 /
                                        total[i], pos[i] * 100 / total[i]))

        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()

        print("Done testing")
Exemplo n.º 8
0
def main(model_type):
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # which model to run
    if model_type == "kgcvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = True
    elif model_type == "cvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = False
    elif model_type == 'hierbaseline':
        model_class = HierBaseline
        backward_size = config.backward_size
    else:
        raise ValueError("Unknown model_type: %s" % model_type)

    # LDA Model
    ldamodel = LDAModel(config,
                        trained_model_path=FLAGS.lda_model_path,
                        id2word_path=FLAGS.id2word_path)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size,
                           vocab_dict_path=FLAGS.vocab_dict_path,
                           lda_model=ldamodel,
                           imdb=FLAGS.use_imdb)

    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get(
        "train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get(
        "train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # if you're testing an existing implementation or resuming training
    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type,
                               FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type,
                               "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w,
                                                    config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = model_class(
                sess,
                config,
                api,
                log_dir=None if FLAGS.forward_only else log_dir,
                forward=False,
                scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = model_class(sess,
                                      valid_config,
                                      api,
                                      log_dir=None,
                                      forward=False,
                                      scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = model_class(sess,
                                     test_config,
                                     api,
                                     log_dir=None,
                                     forward=True,
                                     scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" %
                  ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # if you're training a model
        if not FLAGS.forward_only:

            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf

            # train for at most config.max_epoch epochs; save the model after an epoch only if it improves enough on the current best
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" %
                      (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size,
                                          backward_size,
                                          config.step_size,
                                          shuffle=True)

                global_t, train_loss = model.train(
                    global_t,
                    sess,
                    train_feed,
                    update_limit=config.update_limit)

                # begin validation and testing
                valid_feed.epoch_init(valid_config.batch_size,
                                      valid_config.backward_size,
                                      valid_config.step_size,
                                      shuffle=False,
                                      intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size,
                                     test_config.backward_size,
                                     test_config.step_size,
                                     shuffle=True,
                                     intra_shuffle=False)
                test_model.test(
                    sess, test_feed, num_batch=1
                )  #TODO change this batch size back to a reasonably large number

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease the learning rate if no improvement was seen over the last 3 evaluations.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if True:  #valid_loss < best_dev_loss: # TODO this change makes the model always save. Change this back when corpus not trivial
                    if True:  #valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience,
                                       done_epoch * config.patient_increase)
                        # dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess,
                                         dm_checkpoint_path,
                                         global_step=epoch)
                    # best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            # print("Best validation loss %f" % best_dev_loss)
            print("Done training")

        # else if you're just testing an existing model
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # begin testing
            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            test_model.test(sess,
                            test_feed,
                            num_batch=None,
                            repeat=10,
                            dest=dest_f)
            dest_f.close()
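
The training loop above relies on a patience-based early-stopping rule: an epoch that sufficiently improves validation loss extends patience to done_epoch * patient_increase, and training stops once patience no longer exceeds the number of completed epochs. A minimal, framework-free sketch of that rule follows; the default values are illustrative assumptions, not the project's Config settings.

def should_stop(valid_losses, initial_patience=10, patient_increase=2.0,
                improve_threshold=0.996):
    # Return the epoch at which early stopping would trigger, or None.
    patience = initial_patience
    best_loss = float("inf")
    loss_threshold = float("inf")
    for done_epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            if loss <= loss_threshold * improve_threshold:
                # an improving epoch extends patience multiplicatively
                patience = max(patience, done_epoch * patient_increase)
                loss_threshold = loss
            best_loss = loss
        if patience <= done_epoch:
            return done_epoch
    return None


# Improvements stop after epoch 3, so the initial patience of 10 runs out at epoch 10.
print(should_stop([3.1, 2.9, 2.8, 2.85, 2.84, 2.83, 2.83, 2.83, 2.83, 2.83, 2.83]))
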
Exemplo n.º 9
0
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    # valid_config.keep_prob = 1.0
    # valid_config.dec_keep_prob = 1.0
    # valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    config.n_state = FLAGS.n_state
    valid_config.n_state = FLAGS.n_state
    test_config.n_state = FLAGS.n_state

    config.with_direct_transition = FLAGS.with_direct_transition
    valid_config.with_direct_transition = FLAGS.with_direct_transition
    test_config.with_direct_transition = FLAGS.with_direct_transition

    config.with_word_weights = FLAGS.with_word_weights
    valid_config.with_word_weights = FLAGS.with_word_weights
    test_config.with_word_weights = FLAGS.with_word_weights

    pp(config)

    print(config.n_state)
    print(config.with_direct_transition)
    print(config.with_word_weights)
    # get data set
    # api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    with open(config.api_dir, "rb") as fh:  # binary mode for pickle
        api = pkl.load(fh)
    dial_corpus = api.get_dialog_corpus()
    if config.with_label_loss:
        labeled_dial_labels = api.get_state_corpus(
            config.max_dialog_len)['labeled']
    # meta_corpus = api.get_meta_corpus()

    # train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, labeled_dial, test_dial = dial_corpus.get(
        "train"), dial_corpus.get("labeled"), dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    train_feed = SWDADataLoader("Train", train_dial, config)
    # valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, config)
    if config.with_label_loss:
        labeled_feed = SWDADataLoader("Labeled",
                                      labeled_dial,
                                      config,
                                      labeled=True)
    valid_feed = test_feed

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    with tf.Session(config=tf.ConfigProto(log_device_placement=True,
                                          allow_soft_placement=True)) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w,
                                                    config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = VRNN(sess,
                         config,
                         api,
                         log_dir=None if FLAGS.forward_only else log_dir,
                         scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = VRNN(sess,
                               valid_config,
                               api,
                               log_dir=None,
                               scope=scope)
        #with tf.variable_scope(scope, reuse=True, initializer=initializer):
        #    test_model = KgRnnCVAE(sess, test_config, api, log_dir=None, forward=True, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.W_embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" %
                  ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            #print([str(op.name) for op in tf.get_default_graph().get_operations()])
            print([str(v.name) for v in tf.global_variables()])
            import sys
            # sys.exit()

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = config.n_epoch  # wait at least this many epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" %
                      (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, shuffle=True)

                if config.with_label_loss:
                    labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                else:
                    labeled_feed = None
                    labeled_dial_labels = None
                global_t, train_loss = model.train(
                    global_t,
                    sess,
                    train_feed,
                    labeled_feed,
                    labeled_dial_labels,
                    update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(config.batch_size, shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed,
                                               labeled_feed,
                                               labeled_dial_labels)
                """
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)
                """

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease the learning rate if no improvement was seen over the last 3 evaluations.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)
                """
                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss
                """
                # still save the best train model
                if FLAGS.save_model:
                    print("Save model!!")
                    model.saver.save(sess,
                                     dm_checkpoint_path,
                                     global_step=epoch)

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # begin validation
            global_t = 1
            for epoch in range(1):
                print("test-----------")
                print(">> Epoch %d with lr %f" %
                      (epoch, model.learning_rate.eval()))

            if not FLAGS.use_test_batch:
                # extract latent states from the training data
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, shuffle=False)
                if config.with_label_loss:
                    labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                else:
                    labeled_feed = None
                    labeled_dial_labels = None
                results, fetch_results = model.get_zt(
                    global_t,
                    sess,
                    train_feed,
                    update_limit=config.update_limit,
                    labeled_feed=labeled_feed,
                    labeled_labels=labeled_dial_labels)
                with open(FLAGS.result_path, "wb") as fh:
                    pkl.dump(results, fh)
                with open(FLAGS.result_path + ".param.pkl", "wb") as fh:
                    pkl.dump(fetch_results, fh)
            else:
                print("use_test_batch")
                # extract latent states from the test batches
                valid_feed.epoch_init(config.batch_size, shuffle=False)

                if config.with_label_loss:
                    labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                else:
                    labeled_feed = None
                    labeled_dial_labels = None
                results, fetch_results = model.get_zt(
                    global_t,
                    sess,
                    valid_feed,
                    update_limit=config.update_limit,
                    labeled_feed=labeled_feed,
                    labeled_labels=labeled_dial_labels)
                with open(FLAGS.result_path, "wb") as fh:
                    pkl.dump(results, fh)
                with open(FLAGS.result_path + ".param.pkl", "wb") as fh:
                    pkl.dump(fetch_results, fh)
            """
Exemplo n.º 10
0
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    # pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    # LOG #
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))

    cur_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    # save arguments
    json.dump(
        vars(args),
        open('./output/{}/{}_args.json'.format(args.expname, cur_time), 'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/logs_{}.txt".format(
        args.expname, cur_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # this path must be passed to LoadPoem explicitly on the command line, since its default is None
    # process the pretrain data and the full poem data
    api = LoadPoem(args.train_data_dir, args.test_data_dir,
                   args.max_vocab_size)

    # alternating training: prepare the full dataset
    poem_corpus = api.get_poem_corpus()  # corpus for training and validation
    test_data = api.get_test_corpus()  # test data
    # three lists; each element is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus.get(
        "train"), poem_corpus.get("valid"), test_data.get("test")

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        ###############################################################################
        # Define the models and word2vec weight
        ###############################################################################
        # handle word2vec trained on the Siku Quanshu corpus
        # if args.model != "Seq2Seq"

        # logger.info("Start loading siku word2vec")
        # pretrain_weight = None
        # if os.path.exists(args.word2vec_path):
        #     pretrain_vec = {}
        #     word2vec = open(args.word2vec_path)
        #     pretrain_data = word2vec.read().split('\n')[1:]
        #     for data in pretrain_data:
        #         data = data.split(' ')
        #         pretrain_vec[data[0]] = [float(item) for item in data[1:-1]]
        #     # nparray (vocab_len, emb_dim)
        #     pretrain_weight = process_pretrain_vec(pretrain_vec, api.vocab)
        #     logger.info("Successfully loaded siku word2vec")

        # import pdb
        # pdb.set_trace()

        # a Gaussian mixture model is used whether or not we pretrain
        # when pretraining, each specific Gaussian component is trained on its specific data
        # when not pretraining, the Gaussian mixture is trained on the full dataset

        if args.model == "Seq2Seq":
            model = Seq2Seq(config=config, api=api)
        else:
            model = PoemWAE(config=config, api=api)

        if use_cuda:
            model = model.cuda()
        # if corpus.word2vec is not None and args.reload_from<0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # The model is still PoemWAE_GMP, unchanged; this data is only used to force-train one of its Gaussian priors
        # pretrain = True

        tb_writer = SummaryWriter(
            "./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)\
            + datetime.now().strftime('%Y%m%d%H%M')) if args.visual else None

        global_iter = 1
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        batch_idx = 0
        while global_iter < 100:
            batch_idx = 0
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records = \
                    train_process(model=model, train_loader=train_loader, config=config, sentiment_data=False)

                batch_idx += 1
                if finish_train:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    evaluate_process(model=model,
                                     valid_loader=valid_loader,
                                     global_iter=global_iter,
                                     epoch=global_iter,
                                     logger=logger,
                                     tb_writer=tb_writer,
                                     api=api)
                    # save model after each epoch
                    save_model(model=model,
                               epoch=global_iter,
                               global_iter=global_iter,
                               batch_idx=batch_idx)
                    logger.info(
                        'Finish epoch %d, current min valid loss: %.4f \
                     correspond global_itr: %d  epoch: %d  itr: %d \n\n' %
                        (global_iter, cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write a log entry
                if batch_idx % (train_loader.num_batch // 50) == 0:
                    log = 'Global iter %d: step: %d/%d: ' \
                          % (global_iter, batch_idx, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_iter)
                    logger.info(log)

                # valid
                if batch_idx % (train_loader.num_batch // 10) == 0:
                    valid_process(
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        global_iter=global_iter,
                        unlabeled_epoch=global_iter,  # if sample_rate_unlabeled is not 1, add 1 here at the end
                        batch_idx=batch_idx,
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    save_model(model=model,
                               epoch=global_iter,
                               global_iter=global_iter,
                               batch_idx=batch_idx)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)

            global_iter += 1

    # forward_only: testing mode
    else:
        # test_global_list = [4, 4, 2]
        # test_epoch_list = [21, 19, 8]
        test_global_list = [8]
        test_epoch_list = [20]
        for i in range(1):
            # import pdb
            # pdb.set_trace()
            model = load_model('./output/basic/header_model.pckl')
            model.vocab = api.vocab
            model.rev_vocab = api.rev_vocab
            test_loader.epoch_init(test_config.batch_size, shuffle=False)

            last_title = None
            while True:
                model.eval()  # eval() mainly affects BatchNorm, dropout, and similar layers
                batch = get_user_input(api.rev_vocab, config.title_size)

                # batch = test_loader.next_batch_test()  # test data uses a dedicated batch method
                # import pdb
                # pdb.set_trace()
                if batch is None:
                    break

                title_list, headers, title = batch  # batch size is 1: one batch generates one poem

                if title == last_title:
                    continue
                last_title = title

                title_tensor = to_tensor(title_list)

                # test() decodes the poem for the current batch; each decoding step's input context is the previous step's output
                output_poem = model.test(title_tensor=title_tensor,
                                         title_words=title_list,
                                         headers=headers)
                with open('./content_from_remote.txt', 'w') as file:
                    file.write(output_poem)
                print(output_poem)
                print('\n')
            print("Done testing")