Example #1
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    pretrain = False

    ###############################################################################
    # Logs
    ###############################################################################
    log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    os.makedirs('./output/{}/{}'.format(args.expname, log_start_time),
                exist_ok=True)

    # save arguments
    with open('./output/{}/{}/args.json'.format(args.expname, log_start_time),
              'w') as args_file:
        json.dump(vars(args), args_file)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
        args.expname, log_start_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
        args.expname, log_start_time)) if args.visual else None

    ###############################################################################
    # Model
    ###############################################################################
    # vocab and rev_vocab
    with open(args.vocab_path) as vocab_file:
        vocab = vocab_file.read().strip().split('\n')
        rev_vocab = {word: idx for idx, word in enumerate(vocab)}

    if not pretrain:
        # when skipping pretraining, reload a saved model; otherwise `model`
        # is undefined in the training loop below
        assert config.reload_model
        model = load_model(config.model_name)
    else:
        if args.model == "multiVAE":
            model = multiVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        else:
            model = CVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        if use_cuda:
            model = model.cuda()
    ###############################################################################
    # Load data
    ###############################################################################

    if pretrain:
        from collections import defaultdict
        api = LoadPretrainPoem(corpus_path=args.pretrain_data_dir,
                               vocab_path="data/vocab.txt")

        train_corpus, valid_corpus = defaultdict(list), defaultdict(list)
        divide = 50000
        train_corpus['pos'], valid_corpus['pos'] = api.data[
            'pos'][:divide], api.data['pos'][divide:]
        train_corpus['neu'], valid_corpus['neu'] = api.data[
            'neu'][:divide], api.data['neu'][divide:]
        train_corpus['neg'], valid_corpus['neg'] = api.data[
            'neg'][:divide], api.data['neg'][divide:]

        token_corpus = {
            senti: api.get_tokenized_poem_corpus(train_corpus[senti],
                                                 valid_corpus[senti])
            for senti in ('pos', 'neu', 'neg')
        }

        train_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['train'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['train'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['train'],
                                  config)
        }

        valid_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['valid'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['valid'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['valid'],
                                  config)
        }
        ###############################################################################
        # Pretrain three VAEs
        ###############################################################################

        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch

                model, finish_train, loss_records, global_t = \
                    pre_train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break
                # write logs
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id,
                                    global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records,
                                    logger=logger,
                                    tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    pre_valid_process(global_t=global_t,
                                      model=model,
                                      valid_loader=valid_loader,
                                      valid_config=valid_config,
                                      tb_writer=tb_writer,
                                      logger=logger)
                if global_t % config.test_every == 0:
                    pre_test_process(model=model, logger=logger)
    ###############################################################################
    # Train the big model
    ###############################################################################
    api = LoadPoem(corpus_path=args.train_data_dir,
                   vocab_path="data/vocab.txt",
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)
    from collections import defaultdict
    token_corpus = {
        senti: api.get_tokenized_poem_corpus(api.train_corpus[senti],
                                             api.valid_corpus[senti])
        for senti in ('pos', 'neu', 'neg')
    }

    train_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['train'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['train'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['train'], config)
    }

    valid_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['valid'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['valid'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['valid'], config)
    }
    test_poem = api.get_tokenized_test_corpus()['test']  # test data
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # the model is still PoemWAE_GMP, unchanged; this data just
        # force-trains one of its Gaussian priors
        # pretrain = True

        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # write logs
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id,
                                    global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records,
                                    logger=logger,
                                    tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    valid_process(global_t=global_t,
                                  model=model,
                                  valid_loader=valid_loader,
                                  valid_config=valid_config,
                                  tb_writer=tb_writer,
                                  logger=logger)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)

        # forward_only: testing
    else:
        expname = 'trainVAE'
        time = '202101231631'

        model = load_model(
            './output/{}/{}/model_global_t_26250_epoch9.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]
        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # sent_dict = {0: ['0', '1', '1', '0'], 1: ['2', '1', '2', '2'], 2: ['1', '0', '1', '2']}
        sent_dict = {
            0: ['0', '0', '0', '0'],
            1: ['1', '1', '1', '1'],
            2: ['2', '2', '2', '2']
        }
        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses a dedicated batch
            if batch is None:
                break
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            title_list = batch  # batch size is 1: one batch generates one poem
            title_tensor = to_tensor(title_list)
            # model.test decodes the poem for the current batch; the context
            # fed into each decoding step is the previous step's output
            for i in range(3):
                # copy the labels so repeated batches don't keep growing the
                # shared sent_dict entry
                sent_labels = list(sent_dict[i])
                for _ in range(4):
                    sent_labels.append(str(i))

                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sent_labels=sent_labels)

                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment using the sort net
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            cur_sent_result, neg[i], neu[i], pos[i] = test_sentiment(
                predict_results[i])
            sentiment_result[i] = cur_sent_result
            total[i] = neg[i] + neu[i] + pos[i]

        for i in range(3):
            print("%d%%\t%d%%\t%d%%" % (neg[i] * 100 / total[i], neu[i] * 100 /
                                        total[i], pos[i] * 100 / total[i]))

        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()

        print("Done testing")
Example #2
def main():

    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # this path must be passed to LoadPoem explicitly on the command line,
    # because it defaults to None
    # handle both the pretraining data and the full poem data

    # api = LoadPoem(args.train_data_dir, args.test_data_dir, args.max_vocab_size)
    api = LoadPoem(corpus_path=args.train_data_dir,
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)

    # alternating training: prepare the full dataset
    poem_corpus = api.get_tokenized_poem_corpus(
        type=1 + int(with_sentiment))  # corpus for training and validation
    test_data = api.get_tokenized_test_corpus()  # test data
    # three lists; each element is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus["train"], poem_corpus[
        "valid"], test_data["test"]

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # LOG #
        log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
        os.makedirs('./output/{}/{}'.format(args.expname, log_start_time),
                    exist_ok=True)

        # save arguments
        with open(
                './output/{}/{}/args.json'.format(args.expname,
                                                  log_start_time),
                'w') as args_file:
            json.dump(vars(args), args_file)

        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
            args.expname, log_start_time))
        # add the handlers to the logger
        logger.addHandler(fh)
        logger.info(vars(args))

        tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
            args.expname, log_start_time)) if args.visual else None

        if config.reload_model:
            model = load_model(config.model_name)
        else:
            if args.model == "mCVAE":
                model = CVAE_GMP(config=config, api=api)
            elif args.model == 'CVAE':
                model = CVAE(config=config, api=api)
            else:
                model = Seq2Seq(config=config, api=api)
            if use_cuda:
                model = model.cuda()

        # if corpus.word2vec is not None and args.reload_from<0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # the model is still PoemWAE_GMP, unchanged; this data just
        # force-trains one of its Gaussian priors
        # pretrain = True

        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        while epoch_id < config.epochs:

            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader, config=config, sentiment_data=with_sentiment)
                if finish_train:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    # evaluate_process(model=model, valid_loader=valid_loader, log_start_time=log_start_time, global_t=global_t, epoch=epoch_id, logger=logger, tb_writer=tb_writer, api=api)
                    # save model after each epoch
                    save_model(model=model,
                               epoch=epoch_id,
                               global_t=global_t,
                               log_start_time=log_start_time)
                    logger.info(
                        'Finish epoch %d, current min valid loss: %.4f '
                        'correspond global_itr: %d  epoch: %d  itr: %d \n\n' %
                        (epoch_id, cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    epoch_id += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write logs
                if global_t % config.log_every == 0:
                    log = 'Epoch id %d: step: %d/%d: ' \
                          % (epoch_id, global_t % train_loader.num_batch, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        if loss_name == 'avg_lead_loss':
                            continue
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_t)
                    logger.info(log)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    valid_process(
                        global_t=global_t,
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        # if sample_rate_unlabeled is not 1, add 1 at the end here
                        unlabeled_epoch=epoch_id,
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)

    # forward_only: testing
    else:
        expname = 'sentInput'
        time = '202101191105'

        model = load_model(
            './output/{}/{}/model_global_t_13596_epoch3.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]

        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses a dedicated batch
            if batch is None:
                break
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            title_list = batch  # batch size is 1: one batch generates one poem
            title_tensor = to_tensor(title_list)
            # model.test decodes the poem for the current batch; the context
            # fed into each decoding step is the previous step's output
            for i in range(3):
                sentiment_label = np.zeros(1, dtype=np.int64)
                sentiment_label[0] = int(i)
                sentiment_label = to_tensor(sentiment_label)
                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sentiment_label=sentiment_label)

                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            sentiment_result[i], neg[i], neu[i], pos[i] = test_sentiment(
                predict_results[i])
            total[i] = neg[i] + neu[i] + pos[i]

        for i in range(3):
            print("%d%%\t%d%%\t%d%%" % (neg[i] * 100 / total[i], neu[i] * 100 /
                                        total[i], pos[i] * 100 / total[i]))

        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()

        print("Done testing")
Example #3
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    # pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    # LOG #
    os.makedirs('./output/{}'.format(args.expname), exist_ok=True)

    cur_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    # save arguments
    with open('./output/{}/{}_args.json'.format(args.expname, cur_time),
              'w') as args_file:
        json.dump(vars(args), args_file)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/logs_{}.txt".format(
        args.expname, cur_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # this path must be passed to LoadPoem explicitly on the command line,
    # because it defaults to None
    # handle both the pretraining data and the full poem data
    api = LoadPoem(args.train_data_dir, args.test_data_dir,
                   args.max_vocab_size)

    # alternating training: prepare the full dataset
    poem_corpus = api.get_poem_corpus()  # corpus for training and validation
    test_data = api.get_test_corpus()  # test data
    # three lists; each element is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus.get(
        "train"), poem_corpus.get("valid"), test_data.get("test")

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        ###############################################################################
        # Define the models and word2vec weight
        ###############################################################################
        # handle the word2vec trained on the Siku Quanshu corpus
        # if args.model != "Seq2Seq"

        # logger.info("Start loading siku word2vec")
        # pretrain_weight = None
        # if os.path.exists(args.word2vec_path):
        #     pretrain_vec = {}
        #     word2vec = open(args.word2vec_path)
        #     pretrain_data = word2vec.read().split('\n')[1:]
        #     for data in pretrain_data:
        #         data = data.split(' ')
        #         pretrain_vec[data[0]] = [float(item) for item in data[1:-1]]
        #     # nparray (vocab_len, emb_dim)
        #     pretrain_weight = process_pretrain_vec(pretrain_vec, api.vocab)
        #     logger.info("Successfully loaded siku word2vec")

        # import pdb
        # pdb.set_trace()

        # use a Gaussian mixture model whether or not we pretrain
        # when pretraining, train a specific Gaussian with specific data
        # without pretraining, train the Gaussian mixture on the full dataset

        if args.model == "Seq2Seq":
            model = Seq2Seq(config=config, api=api)
        else:
            model = PoemWAE(config=config, api=api)

        if use_cuda:
            model = model.cuda()
        # if corpus.word2vec is not None and args.reload_from<0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # the model is still PoemWAE_GMP, unchanged; this data just
        # force-trains one of its Gaussian priors
        # pretrain = True

        tb_writer = SummaryWriter(
            "./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)\
            + datetime.now().strftime('%Y%m%d%H%M')) if args.visual else None

        global_iter = 1
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        batch_idx = 0
        while global_iter < 100:
            batch_idx = 0
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records = \
                    train_process(model=model, train_loader=train_loader, config=config, sentiment_data=False)

                batch_idx += 1
                if finish_train:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    evaluate_process(model=model,
                                     valid_loader=valid_loader,
                                     global_iter=global_iter,
                                     epoch=global_iter,
                                     logger=logger,
                                     tb_writer=tb_writer,
                                     api=api)
                    # save model after each epoch
                    save_model(model=model,
                               epoch=global_iter,
                               global_iter=global_iter,
                               batch_idx=batch_idx)
                    logger.info(
                        'Finish epoch %d, current min valid loss: %.4f '
                        'correspond global_itr: %d  epoch: %d  itr: %d \n\n' %
                        (global_iter, cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write logs
                if batch_idx % (train_loader.num_batch // 50) == 0:
                    log = 'Global iter %d: step: %d/%d: ' \
                          % (global_iter, batch_idx, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_iter)
                    logger.info(log)

                # valid
                if batch_idx % (train_loader.num_batch // 10) == 0:
                    valid_process(
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        global_iter=global_iter,
                        # if sample_rate_unlabeled is not 1, add 1 at the end here
                        unlabeled_epoch=global_iter,
                        batch_idx=batch_idx,
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    save_model(model=model,
                               epoch=global_iter,
                               global_iter=global_iter,
                               batch_idx=batch_idx)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)

            global_iter += 1

    # forward_only: testing
    else:
        # test_global_list = [4, 4, 2]
        # test_epoch_list = [21, 19, 8]
        test_global_list = [8]
        test_epoch_list = [20]
        for i in range(1):
            # import pdb
            # pdb.set_trace()
            model = load_model('./output/basic/header_model.pckl')
            model.vocab = api.vocab
            model.rev_vocab = api.rev_vocab
            test_loader.epoch_init(test_config.batch_size, shuffle=False)

            last_title = None
            while True:
                model.eval()  # eval() mainly affects BatchNorm, dropout, etc.
                batch = get_user_input(api.rev_vocab, config.title_size)

                # batch = test_loader.next_batch_test()  # test data uses a dedicated batch
                # import pdb
                # pdb.set_trace()
                if batch is None:
                    break

                title_list, headers, title = batch  # batch size is 1: one batch generates one poem

                if title == last_title:
                    continue
                last_title = title

                title_tensor = to_tensor(title_list)

                # model.test decodes the poem for the current batch; the
                # context fed into each decoding step is the previous result
                output_poem = model.test(title_tensor=title_tensor,
                                         title_words=title_list,
                                         headers=headers)
                with open('./content_from_remote.txt', 'w') as file:
                    file.write(output_poem)
                print(output_poem)
                print('\n')
            print("Done testing")