def main():
    """Train (or, with --forward_only, decode) the sentiment poem model.

    Builds train/valid/test configs, loads the tokenized poem corpus, then
    either runs the train/valid/test loop (saving logs and checkpoints under
    ./output/<expname>/<timestamp>/) or reloads a fixed checkpoint and decodes
    every test title once per sentiment label (0/1/2).

    Relies on module-level names from this file: Config, pp, args, LoadPoem,
    SWDADataLoader, use_cuda, load_model/save_model, the *_process helpers,
    to_tensor, test_sentiment and write_predict_result_to_file.
    """
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    # validation runs with dropout disabled and a fixed batch size
    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test: one poem per batch
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    ###########################################################################
    # Load data
    ###########################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # That path must be supplied explicitly on the command line (default None).
    # Prepares both the pretrain data and the full poem corpus.
    api = LoadPoem(corpus_path=args.train_data_dir,
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)

    # Alternating training: assemble the large dataset.
    poem_corpus = api.get_tokenized_poem_corpus(
        type=1 + int(with_sentiment))  # corpus for training and validation
    test_data = api.get_tokenized_test_corpus()  # test data
    # Three lists; every element is [topic, last_sentence, current_sentence].
    train_poem, valid_poem, test_poem = poem_corpus["train"], poem_corpus[
        "valid"], test_data["test"]

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # LOG #
        log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
        if not os.path.isdir('./output'):
            os.makedirs('./output')
        if not os.path.isdir('./output/{}'.format(args.expname)):
            os.makedirs('./output/{}'.format(args.expname))
        if not os.path.isdir('./output/{}/{}'.format(args.expname,
                                                     log_start_time)):
            os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

        # save arguments
        json.dump(
            vars(args),
            open(
                './output/{}/{}/args.json'.format(args.expname,
                                                  log_start_time), 'w'))

        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
            args.expname, log_start_time))
        # add the handlers to the logger
        logger.addHandler(fh)
        logger.info(vars(args))

        tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
            args.expname, log_start_time)) if args.visual else None

        # Either resume a saved model or build one according to --model.
        if config.reload_model:
            model = load_model(config.model_name)
        else:
            if args.model == "mCVAE":
                model = CVAE_GMP(config=config, api=api)
            elif args.model == 'CVAE':
                model = CVAE(config=config, api=api)
            else:
                model = Seq2Seq(config=config, api=api)
            if use_cuda:
                model = model.cuda()

        #######################################################################
        # Start training
        #######################################################################
        # The model is still PoemWAE_GMP; this data is only used to
        # force-train one of its Gaussian prior components.
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)
        epoch_id = 0
        global_t = 0
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model,
                                  train_loader=train_loader, config=config,
                                  sentiment_data=with_sentiment)
                if finish_train:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
                    # save model after each epoch
                    save_model(model=model, epoch=epoch_id, global_t=global_t,
                               log_start_time=log_start_time)
                    # BUG FIX: the original passed the four cur_best_score
                    # values to a 4-spec format whose first %d is the epoch,
                    # so the loss printed as the epoch number and an
                    # iteration count as the loss. Use the 5-spec message
                    # (matching the other training script in this file) and
                    # prepend epoch_id.
                    logger.info(
                        'Finish epoch %d, current min valid loss: %.4f '
                        'correspond global_itr: %d epoch: %d itr: %d \n\n' %
                        (epoch_id, cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # Initialize the next training epoch.
                    epoch_id += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break

                # periodic logging
                if global_t % config.log_every == 0:
                    log = 'Epoch id %d: step: %d/%d: ' \
                          % (epoch_id, global_t % train_loader.num_batch,
                             train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        if loss_name == 'avg_lead_loss':
                            continue
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_t)
                    logger.info(log)

                # periodic validation
                if global_t % config.valid_every == 0:
                    valid_process(
                        global_t=global_t,
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        # if sample_rate_unlabeled is not 1, add 1 here
                        unlabeled_epoch=epoch_id,
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)

                # periodic testing
                if global_t % config.test_every == 0:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
    # forward_only: decode the test set from a fixed checkpoint
    else:
        expname = 'sentInput'
        time = '202101191105'
        model = load_model(
            './output/{}/{}/model_global_t_13596_epoch3.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        # one output file per sentiment label
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time),
                 'w')
        ]

        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}

        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # dedicated test batches
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: one poem per batch
            title_tensor = to_tensor(title_list)
            # model.test decodes this batch's poem; every decode step feeds
            # the previous result back in as context.
            for i in range(3):
                sentiment_label = np.zeros(1, dtype=np.int64)
                sentiment_label[0] = int(i)
                sentiment_label = to_tensor(sentiment_label)
                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sentiment_label=sentiment_label)
                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            _, neg[i], neu[i], pos[i] = test_sentiment(predict_results[i])
            total[i] = neg[i] + neu[i] + pos[i]
        for i in range(3):
            # BUG FIX: the original divided the defaultdict objects
            # themselves (neg * 100 / total), which raises TypeError;
            # index each counter by i as the parallel script below does.
            print("%d%%\t%d%%\t%d%%" %
                  (neg[i] * 100 / total[i], neu[i] * 100 / total[i],
                   pos[i] * 100 / total[i]))
        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()
        print("Done testing")
def forward(self, inputs):
    """Return position encodings for a batch of token ids.

    inputs: [B, max_length] tensor of token ids.
    Looks up each token's position index, then maps positions to their
    encoding vectors via the module's position_encoding layer.
    """
    # inputs [B, max_lenth]
    positions = self.token2position(inputs)
    positions_encoded = self.position_encoding(positions)
    return positions_encoded


if __name__ == '__main__':
    # Smoke-test driver: rebuild a Transformer around the saved sogou
    # vocabulary, wrap it in DataParallel, restore a checkpoint, and open
    # the processed train/dev/test datasets.
    from configs import Config
    import pickle as pk
    vocab = pk.load(open('Predictor/Utils/sogou_vocab.pkl', 'rb'))
    args = Config()
    args.sos_id = vocab.token2id['<BOS>']
    args.batch_size = 1
    print(args.sos_id)
    # pretrained embedding matrix stored alongside the vocabulary
    matrix = vocab.matrix
    transformer = Transformer(args, matrix)
    mm = t.nn.DataParallel(transformer).cuda()
    # output = transformer(inputs)
    # output2 = transformer(inputs)
    # NOTE(review): state dict keys must match the DataParallel-wrapped
    # module ("module."-prefixed) for this load to succeed — confirm the
    # checkpoint was saved from a DataParallel model.
    mm.load_state_dict(t.load('ckpt/20180913_233530/saved_models/2018_09_16_18_31_10T0.6108602118195541/model'))
    from torch.utils.data import Dataset, DataLoader
    from DataSets import DataSet
    from DataSets import own_collate_fn
    from Predictor.Utils import batch_scorer
    train_set = DataSet(args.sog_processed + 'train/')
    dev_set = DataSet(args.sog_processed + 'dev/')
    test_set = DataSet(args.sog_processed + 'test/')
def main():
    """Pretrain three sentiment VAEs (optional) and train the big poem model.

    With pretrain=True, splits the pretrain corpus into pos/neu/neg halves
    and runs a pretraining loop per sentiment; it then always builds the
    full LoadPoem corpus and either trains the big model or, with
    --forward_only, decodes every test title once per sentiment using a
    fixed checkpoint.

    Relies on module-level names from this file: Config, pp, args,
    LoadPretrainPoem, LoadPoem, SWDADataLoader, init_train_loaders,
    use_cuda, multiVAE/CVAE, the *_process helpers, load_model/save_model,
    to_tensor, test_sentiment and write_predict_result_to_file.
    """
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    # validation runs with dropout disabled and a fixed batch size
    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test: one poem per batch
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment
    pretrain = False

    ###########################################################################
    # Logs
    ###########################################################################
    log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))
    if not os.path.isdir('./output/{}/{}'.format(args.expname,
                                                 log_start_time)):
        os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

    # save arguments
    json.dump(
        vars(args),
        open('./output/{}/{}/args.json'.format(args.expname, log_start_time),
             'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/{}/logs.txt".format(
        args.expname, log_start_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
        args.expname, log_start_time)) if args.visual else None

    ###########################################################################
    # Model
    ###########################################################################
    # vocab and rev_vocab
    with open(args.vocab_path) as vocab_file:
        vocab = vocab_file.read().strip().split('\n')
    rev_vocab = {vocab[idx]: idx for idx in range(len(vocab))}

    if not pretrain:
        # NOTE(review): with pretrain=False nothing binds `model` here, so
        # the training loop below raises NameError unless the commented
        # reload path is restored — confirm intended checkpoint loading.
        pass
        # assert config.reload_model
        # model = load_model(config.model_name)
    else:
        if args.model == "multiVAE":
            model = multiVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        else:
            model = CVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        if use_cuda:
            model = model.cuda()

    ###########################################################################
    # Load data
    ###########################################################################
    if pretrain:
        from collections import defaultdict
        api = LoadPretrainPoem(corpus_path=args.pretrain_data_dir,
                               vocab_path="data/vocab.txt")

        # First `divide` poems of each sentiment train; the rest validate.
        train_corpus, valid_corpus = defaultdict(list), defaultdict(list)
        divide = 50000
        train_corpus['pos'], valid_corpus['pos'] = api.data[
            'pos'][:divide], api.data['pos'][divide:]
        train_corpus['neu'], valid_corpus['neu'] = api.data[
            'neu'][:divide], api.data['neu'][divide:]
        train_corpus['neg'], valid_corpus['neg'] = api.data[
            'neg'][:divide], api.data['neg'][divide:]

        token_corpus = defaultdict(dict)
        token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
            api.get_tokenized_poem_corpus(train_corpus['pos'], valid_corpus['pos']), \
            api.get_tokenized_poem_corpus(train_corpus['neu'], valid_corpus['neu']), \
            api.get_tokenized_poem_corpus(train_corpus['neg'], valid_corpus['neg'])

        train_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['train'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['train'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['train'],
                                  config)
        }
        valid_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['valid'],
                                  config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['valid'],
                                  config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['valid'],
                                  config)
        }

        #######################################################################
        # Pretrain three VAEs
        #######################################################################
        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    pre_train_process(global_t=global_t, model=model,
                                      train_loader=train_loader)
                if finish_train:
                    # checkpoints are only worth keeping after epoch 5
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # periodic logging
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id, global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records, logger=logger,
                                    tb_writer=tb_writer)

                # periodic validation
                if global_t % config.valid_every == 0:
                    pre_valid_process(global_t=global_t, model=model,
                                      valid_loader=valid_loader,
                                      valid_config=valid_config,
                                      tb_writer=tb_writer, logger=logger)

                if global_t % config.test_every == 0:
                    pre_test_process(model=model, logger=logger)

    ###########################################################################
    # Train the big model
    ###########################################################################
    api = LoadPoem(corpus_path=args.train_data_dir,
                   vocab_path="data/vocab.txt",
                   test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)

    from collections import defaultdict
    token_corpus = defaultdict(dict)
    token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
        api.get_tokenized_poem_corpus(api.train_corpus['pos'], api.valid_corpus['pos']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neu'], api.valid_corpus['neu']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neg'], api.valid_corpus['neg'])

    train_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['train'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['train'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['train'], config)
    }
    valid_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['valid'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['valid'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['valid'], config)
    }

    test_poem = api.get_tokenized_test_corpus()['test']  # test data
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # The model is still PoemWAE_GMP; this data is only used to
        # force-train one of its Gaussian prior components.
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model,
                                  train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model,
                                   epoch=epoch_id,
                                   global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # periodic logging
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id, global_t=global_t,
                                    train_loader=train_loader,
                                    loss_records=loss_records, logger=logger,
                                    tb_writer=tb_writer)

                # periodic validation
                if global_t % config.valid_every == 0:
                    valid_process(global_t=global_t, model=model,
                                  valid_loader=valid_loader,
                                  valid_config=valid_config,
                                  tb_writer=tb_writer, logger=logger)

                if global_t % config.test_every == 0:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
    # forward_only: decode the test set from a fixed checkpoint
    else:
        expname = 'trainVAE'
        time = '202101231631'
        model = load_model(
            './output/{}/{}/model_global_t_26250_epoch9.pckl'.format(
                expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        # one output file per sentiment label
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time),
                 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time),
                 'w')
        ]

        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # per-sentiment label prefix for the first four sentences
        sent_dict = {
            0: ['0', '0', '0', '0'],
            1: ['1', '1', '1', '1'],
            2: ['2', '2', '2', '2']
        }

        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # dedicated test batches
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: one poem per batch
            title_tensor = to_tensor(title_list)
            # model.test decodes this batch's poem; every decode step feeds
            # the previous result back in as context.
            for i in range(3):
                # BUG FIX: the original aliased the dict entry
                # (sent_labels = sent_dict[i]) and appended in place, so
                # sent_dict[i] grew by four labels on every poem and
                # model.test received ever-longer label lists. Copy first
                # so each poem sees exactly eight labels.
                sent_labels = list(sent_dict[i])
                for _ in range(4):
                    sent_labels.append(str(i))
                output_poem, output_tokens = model.test(
                    title_tensor, title_list, sent_labels=sent_labels)
                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += (np.array(output_tokens)[:, :7].tolist())

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            cur_sent_result, neg[i], neu[i], pos[i] = test_sentiment(
                predict_results[i])
            sentiment_result[i] = cur_sent_result
            total[i] = neg[i] + neu[i] + pos[i]
        for i in range(3):
            print("%d%%\t%d%%\t%d%%" %
                  (neg[i] * 100 / total[i], neu[i] * 100 / total[i],
                   pos[i] * 100 / total[i]))
        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()
        print("Done testing")
def default_detection_configs(phi,
                              min_level=3,
                              max_level=7,
                              fpn_filters=64,
                              neck_repeats=3,
                              head_repeats=3,
                              anchor_scale=4,
                              num_scales=3,
                              batch_size=4,
                              image_size=512,
                              fusion_type="weighted_sum"):
    """Build the default config for an EfficientDet-d<phi> detector.

    Args:
        phi: compound-scaling coefficient; selects efficientnet-b<phi>
            backbone and the efficientdet-d<phi> weight/checkpoint paths.
        min_level, max_level: pyramid levels used by the neck/head/anchors.
        fpn_filters: feature dims for BiFPN and head convolutions.
        neck_repeats, head_repeats: number of BiFPN / head conv repeats.
        anchor_scale: base anchor size multiplier per level stride.
        num_scales: anchor octave scales per level.
        batch_size, image_size: training input configuration.
        fusion_type: BiFPN feature-fusion mode.

    Returns:
        A populated Config object `h`.
    """
    h = Config()

    # model name
    h.detector = "efficientdet-d%d" % phi
    h.min_level = min_level
    h.max_level = max_level
    h.dtype = "float16"

    # backbone
    h.backbone = dict(backbone="efficientnet-b%d" % phi,
                      convolution="depthwise_conv2d",
                      dropblock=None,
                      # dropblock=dict(keep_prob=None,
                      #                block_size=None)
                      normalization=dict(normalization="batch_norm",
                                         momentum=0.99,
                                         epsilon=1e-3,
                                         axis=-1,
                                         trainable=False),
                      activation=dict(activation="swish"),
                      strides=[2, 1, 2, 2, 2, 1, 2, 1],
                      dilation_rates=[1, 1, 1, 1, 1, 1, 1, 1],
                      output_indices=[3, 4, 5],
                      frozen_stages=[-1])

    # neck
    h.neck = dict(neck="bifpn",
                  repeats=neck_repeats,
                  convolution="separable_conv2d",
                  dropblock=None,
                  # dropblock=dict(keep_prob=None,
                  #                block_size=None)
                  feat_dims=fpn_filters,
                  normalization=dict(normalization="batch_norm",
                                     momentum=0.99,
                                     epsilon=1e-3,
                                     axis=-1,
                                     trainable=False),
                  activation=dict(activation="swish"),
                  add_extra_conv=False,  # Add extra convolution for neck
                  fusion_type=fusion_type,
                  use_multiplication=False)

    # head
    h.head = dict(head="RetinaNetHead",
                  repeats=head_repeats,
                  convolution="separable_conv2d",
                  dropblock=None,
                  # dropblock=dict(keep_prob=None,
                  #                block_size=None)
                  feat_dims=fpn_filters,
                  normalization=dict(normalization="batch_norm",
                                     momentum=0.99,
                                     epsilon=1e-3,
                                     axis=-1,
                                     trainable=False),
                  activation=dict(activation="swish"),
                  prior=0.01)  # focal-loss style classifier bias prior

    # anchors parameters: one stride per pyramid level, num_scales octave
    # scales per level; 3 ratios x 3 scales = 9 anchors per location.
    strides = [2 ** l for l in range(min_level, max_level + 1)]
    h.anchor = dict(aspect_ratios=[[1., 0.5, 2.]] * (max_level - min_level + 1),
                    scales=[
                        [2 ** (i / num_scales) * s * anchor_scale
                         for i in range(num_scales)] for s in strides
                    ],
                    num_anchors=9)

    # assigner
    h.assigner = dict(assigner="max_iou_assigner",
                      pos_iou_thresh=0.5,
                      neg_iou_thresh=0.5)

    # sampler
    h.sampler = dict(sampler="pseudo_sampler")

    # loss
    h.use_sigmoid = True
    h.label_loss = dict(loss="focal_loss",
                        alpha=0.25,
                        gamma=1.5,
                        label_smoothing=0.,
                        weight=1.,
                        from_logits=True,
                        reduction="none")
    h.bbox_loss = dict(loss="smooth_l1_loss",
                       weight=50.,  # 50.
                       delta=.1,  # .1
                       reduction="none")
    # h.box_loss=dict(loss="giou_loss",
    #                 weight=10.,
    #                 reduction="none")
    h.weight_decay = 4e-5
    h.bbox_mean = None  # [0., 0., 0., 0.]
    h.bbox_std = None  # [0.1, 0.1, 0.2, 0.2]

    # dataset
    h.num_classes = 90
    h.skip_crowd_during_training = True
    h.dataset = "objects365"
    h.batch_size = batch_size
    h.input_size = [image_size, image_size]
    h.train_dataset_dir = "/home/bail/Data/data1/Dataset/Objects365/train"
    # NOTE(review): val_dataset_dir points at the *train* split — confirm
    # whether validation on a separate directory was intended.
    h.val_dataset_dir = "/home/bail/Data/data1/Dataset/Objects365/train"
    h.augmentation = [
        dict(ssd_crop=dict(patch_area_range=(0.3, 1.),
                           aspect_ratio_range=(0.5, 2.0),
                           min_overlaps=(0.1, 0.3, 0.5, 0.7, 0.9),
                           max_attempts=100,
                           probability=.5)),
        # dict(data_anchor_sampling=dict(anchor_scales=(16, 32, 64, 128, 256, 512),
        #                                overlap_threshold=0.7,
        #                                max_attempts=50,
        #                                probability=.5)),
        dict(flip_left_to_right=dict(probability=0.5)),
        dict(random_distort_color=dict(probability=1.))
    ]

    # train
    h.pretrained_weights_path = "/home/bail/Workspace/pretrained_weights/efficientdet-d%d" % phi

    h.optimizer = dict(optimizer="sgd", momentum=0.9)
    h.lookahead = None

    h.train_steps = 240000
    h.learning_rate_scheduler = dict(scheduler="cosine",
                                     initial_learning_rate=0.002)
    h.warmup = dict(warmup_learning_rate = 0.00001, steps = 24000)
    h.checkpoint_dir = "checkpoints/efficientdet_d%d" % phi
    h.summary_dir = "logs/efficientdet_d%d" % phi

    h.gradient_clip_norm = .0

    h.log_every_n_steps = 500
    h.save_ckpt_steps = 10000
    h.val_every_n_steps = 4000

    # select top_k high confident detections for nms
    h.postprocess = dict(pre_nms_size=5000,
                         post_nms_size=100,
                         iou_threshold=0.5,
                         score_threshold=0.2)

    return h
def main():
    """Train the PoemWAE/Seq2Seq poem model, or serve interactive decoding.

    Builds train/valid/test configs, loads the poem corpus via LoadPoem,
    then either runs a fixed 100-iteration train/valid/test loop (logging
    under ./output/<expname>/) or, with --forward_only, loads a saved model
    and decodes poems for titles typed in by the user.

    Relies on module-level names from this file: Config, args, LoadPoem,
    SWDADataLoader, use_cuda, Seq2Seq/PoemWAE, the *_process helpers,
    save_model/load_model, get_user_input and to_tensor.
    """
    # config for training
    config = Config()
    print("Normal train config:")
    # pp(config)

    # validation runs with dropout disabled and a fixed batch size
    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test: one poem per batch
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    # LOG #
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))
    cur_time = str(datetime.now().strftime('%Y%m%d%H%M'))

    # save arguments
    json.dump(
        vars(args),
        open('./output/{}/{}_args.json'.format(args.expname, cur_time), 'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/logs_{}.txt".format(
        args.expname, cur_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    ###########################################################################
    # Load data
    ###########################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # That path must be supplied explicitly on the command line (default
    # None). Prepares both the pretrain data and the full poem corpus.
    api = LoadPoem(args.train_data_dir, args.test_data_dir,
                   args.max_vocab_size)

    # Alternating training: assemble the large dataset.
    poem_corpus = api.get_poem_corpus()  # corpus for training and validation
    test_data = api.get_test_corpus()  # test data
    # Three lists; every element is [topic, last_sentence, current_sentence].
    train_poem, valid_poem, test_poem = poem_corpus.get(
        "train"), poem_corpus.get("valid"), test_data.get("test")

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")
    if not args.forward_only:
        #######################################################################
        # Define the models and word2vec weight
        #######################################################################
        # word2vec trained on the Siku Quanshu corpus (currently disabled)
        # if args.model != "Seq2Seq"
        # logger.info("Start loading siku word2vec")
        # pretrain_weight = None
        # if os.path.exists(args.word2vec_path):
        #     pretrain_vec = {}
        #     word2vec = open(args.word2vec_path)
        #     pretrain_data = word2vec.read().split('\n')[1:]
        #     for data in pretrain_data:
        #         data = data.split(' ')
        #         pretrain_vec[data[0]] = [float(item) for item in data[1:-1]]
        #     # nparray (vocab_len, emb_dim)
        #     pretrain_weight = process_pretrain_vec(pretrain_vec, api.vocab)
        #     logger.info("Successfully loaded siku word2vec")

        # Whether pretraining or not, a Gaussian mixture model is used:
        # during pretraining, specific data trains a specific Gaussian;
        # without pretraining, the big dataset trains the full mixture.
        if args.model == "Seq2Seq":
            model = Seq2Seq(config=config, api=api)
        else:
            model = PoemWAE(config=config, api=api)
        if use_cuda:
            model = model.cuda()

        # if corpus.word2vec is not None and args.reload_from<0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        #######################################################################
        # Start training
        #######################################################################
        # The model is still PoemWAE_GMP; this data is only used to
        # force-train one of its Gaussian prior components.
        tb_writer = SummaryWriter(
            "./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)\
            + datetime.now().strftime('%Y%m%d%H%M')) if args.visual else None

        global_iter = 1
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)
        # model = load_model(3, 3)
        batch_idx = 0
        while global_iter < 100:
            batch_idx = 0
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records = \
                    train_process(model=model, train_loader=train_loader,
                                  config=config, sentiment_data=False)
                batch_idx += 1
                if finish_train:
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    evaluate_process(model=model, valid_loader=valid_loader,
                                     global_iter=global_iter,
                                     epoch=global_iter, logger=logger,
                                     tb_writer=tb_writer, api=api)
                    # save model after each epoch
                    save_model(model=model, epoch=global_iter,
                               global_iter=global_iter, batch_idx=batch_idx)
                    logger.info(
                        'Finish epoch %d, current min valid loss: %.4f \
                        correspond global_itr: %d epoch: %d itr: %d \n\n' %
                        (global_iter, cur_best_score['min_valid_loss'],
                         cur_best_score['min_global_itr'],
                         cur_best_score['min_epoch'],
                         cur_best_score['min_itr']))
                    # Initialize the next unlabeled-data training epoch.
                    # unlabeled_epoch += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # periodic logging
                # NOTE(review): the // 50 and // 10 divisors below raise
                # ZeroDivisionError when train_loader.num_batch < 50 —
                # confirm the corpus always yields enough batches.
                if batch_idx % (train_loader.num_batch // 50) == 0:
                    log = 'Global iter %d: step: %d/%d: ' \
                          % (global_iter, batch_idx, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value,
                                                 global_iter)
                    logger.info(log)

                # periodic validation (plus a test pass and checkpoint)
                if batch_idx % (train_loader.num_batch // 10) == 0:
                    valid_process(
                        model=model,
                        valid_loader=valid_loader,
                        valid_config=valid_config,
                        global_iter=global_iter,
                        # if sample_rate_unlabeled is not 1, add 1 here
                        unlabeled_epoch=global_iter,
                        batch_idx=batch_idx,
                        tb_writer=tb_writer,
                        logger=logger,
                        cur_best_score=cur_best_score)
                    test_process(model=model,
                                 test_loader=test_loader,
                                 test_config=test_config,
                                 logger=logger)
                    save_model(model=model, epoch=global_iter,
                               global_iter=global_iter, batch_idx=batch_idx)
            global_iter += 1
    # forward_only: interactive decoding from a saved model
    else:
        # test_global_list = [4, 4, 2]
        # test_epoch_list = [21, 19, 8]
        test_global_list = [8]
        test_epoch_list = [20]
        for i in range(1):
            model = load_model('./output/basic/header_model.pckl')
            model.vocab = api.vocab
            model.rev_vocab = api.rev_vocab
            test_loader.epoch_init(test_config.batch_size, shuffle=False)
            last_title = None
            while True:
                # eval() mainly affects BatchNorm, dropout, etc.
                model.eval()
                batch = get_user_input(api.rev_vocab, config.title_size)
                # batch = test_loader.next_batch_test()
                if batch is None:
                    break
                # batch size is 1: one poem per batch
                title_list, headers, title = batch
                # skip a title the user repeats back-to-back
                if title == last_title:
                    continue
                last_title = title
                title_tensor = to_tensor(title_list)
                # model.test decodes this batch's poem; every decode step
                # feeds the previous result back in as context.
                output_poem = model.test(title_tensor=title_tensor,
                                         title_words=title_list,
                                         headers=headers)
                with open('./content_from_remote.txt', 'w') as file:
                    file.write(output_poem)
                print(output_poem)
                print('\n')
        print("Done testing")