def train():
    # Optionally load pretrained word embeddings: the first line of the file is a header,
    # every following line is a token followed by its vector values.
    embed = None
    if args.embed_path is not None and os.path.exists(args.embed_path):
        print('Loading pretrained word embedding...')
        embed = {}
        with open(args.embed_path, 'r') as f:
            f.readline()
            for line in f.readlines():
                line = line.strip().split()
                vec = [float(_) for _ in line[1:]]
                embed[line[0]] = vec
    vocab = Vocab(args, embed)

    # Read every example of each split and register its words, users, and products
    # in the vocabulary.
    print('Loading datasets...')
    train_data, val_data, test_data = [], [], []
    fns = os.listdir(args.train_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.train_dir + fn, 'r')
        train_data.append(json.load(f))
        f.close()
        vocab.add_sentence(train_data[-1]['reviewText'].split())
        vocab.add_sentence(train_data[-1]['summary'].split())
        vocab.add_user(train_data[-1]['userID'])
        vocab.add_product(train_data[-1]['productID'])
    fns = os.listdir(args.valid_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.valid_dir + fn, 'r')
        val_data.append(json.load(f))
        f.close()
        vocab.add_sentence(val_data[-1]['reviewText'].split())
        vocab.add_sentence(val_data[-1]['summary'].split())
        vocab.add_user(val_data[-1]['userID'])
        vocab.add_product(val_data[-1]['productID'])
    fns = os.listdir(args.test_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.test_dir + fn, 'r')
        test_data.append(json.load(f))
        f.close()
        vocab.add_sentence(test_data[-1]['reviewText'].split())
        vocab.add_sentence(test_data[-1]['summary'].split())
        vocab.add_user(test_data[-1]['userID'])
        vocab.add_product(test_data[-1]['productID'])

    # Drop rare words, then record the final table sizes on args for model construction.
    print('Deleting rare words...')
    embed = vocab.trim()
    args.embed_num = len(embed)
    args.embed_dim = len(embed[0])
    args.user_num = vocab.user_num
    args.product_num = vocab.product_num

    train_dataset = Dataset(train_data)
    val_dataset = Dataset(val_data)
    train_iter = DataLoader(dataset=train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True)
    val_iter = DataLoader(dataset=val_dataset,
                          batch_size=args.batch_size,
                          shuffle=False)

    net = EncoderDecoder(args, embed)
    if args.load_model is not None:
        print('Loading model...')
        checkpoint = torch.load(args.save_path + args.load_model)
        net = EncoderDecoder(checkpoint['args'], embed)
        net.load_state_dict(checkpoint['model'])
    if args.use_cuda:
        net.cuda()

    criterion = nn.NLLLoss(ignore_index=vocab.PAD_IDX, reduction='sum')
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    # Training loop: summed NLL over log-probabilities normalized by batch size,
    # gradient clipping, and periodic validation checkpoints.
    print('Begin training...')
    for epoch in range(args.begin_epoch, args.epochs + 1):
        if epoch >= args.lr_decay_start:
            adjust_learning_rate(optim, epoch - args.lr_decay_start + 1)
        for i, batch in enumerate(train_iter):
            src, trg, src_embed, trg_embed, src_user, src_product, src_mask, \
                src_lens, trg_lens, _1, _2 = vocab.read_batch(batch)
            pre_output = net(src, trg, src_embed, trg_embed, src_user, src_product,
                             vocab.word_num, src_mask, src_lens, trg_lens)
            pre_output = torch.log(pre_output.view(-1, pre_output.size(-1)) + 1e-20)
            trg_output = trg.view(-1)
            loss = criterion(pre_output, trg_output) / len(src_lens)
            loss.backward()
            clip_grad_norm_(net.parameters(), args.max_norm)
            optim.step()
            optim.zero_grad()

            cnt = (epoch - 1) * len(train_iter) + i
            if cnt % args.print_every == 0:
                print('EPOCH [%d/%d]: BATCH_ID=[%d/%d] loss=%f'
                      % (epoch, args.epochs, i, len(train_iter), loss.data))
            if cnt % args.valid_every == 0:
                print('Begin valid... Epoch %d, Batch %d' % (epoch, i))
                cur_loss, r1, r2, rl = evaluate(net, criterion, vocab, val_iter, True)
                save_path = args.save_path + 'valid_%d_%.4f_%.4f_%.4f_%.4f' % (
                    cnt // args.valid_every, cur_loss, r1, r2, rl)
                net.save(save_path)
                print('Epoch: %2d Cur_Val_Loss: %f Rouge-1: %f Rouge-2: %f Rouge-l: %f'
                      % (epoch, cur_loss, r1, r2, rl))
    return

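# adjust_learning_rate() is assumed to be defined elsewhere in this project; train()
# calls it once per epoch after args.lr_decay_start. A minimal sketch of an
# epoch-indexed exponential decay is shown below for reference only; the decay factor
# name args.lr_decay is an assumption, not something this file defines:
#
# def adjust_learning_rate(optim, decay_epoch):
#     lr = args.lr * (args.lr_decay ** decay_epoch)
#     for param_group in optim.param_groups:
#         param_group['lr'] = lr
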
def test():
    embed = None
    if args.embed_path is not None and os.path.exists(args.embed_path):
        print('Loading pretrained word embedding...')
        embed = {}
        with open(args.embed_path, 'r') as f:
            f.readline()
            for line in f.readlines():
                line = line.strip().split()
                vec = [float(_) for _ in line[1:]]
                embed[line[0]] = vec
    vocab = Vocab(args, embed)

    train_data, val_data, test_data = [], [], []
    fns = os.listdir(args.train_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.train_dir + fn, 'r')
        train_data.append(json.load(f))
        f.close()
        vocab.add_sentence(train_data[-1]['reviewText'].split())
        vocab.add_sentence(train_data[-1]['summary'].split())
        vocab.add_user(train_data[-1]['userID'])
        vocab.add_product(train_data[-1]['productID'])
    fns = os.listdir(args.valid_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.valid_dir + fn, 'r')
        val_data.append(json.load(f))
        f.close()
        vocab.add_sentence(val_data[-1]['reviewText'].split())
        vocab.add_sentence(val_data[-1]['summary'].split())
        vocab.add_user(val_data[-1]['userID'])
        vocab.add_product(val_data[-1]['productID'])
    fns = os.listdir(args.test_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.test_dir + fn, 'r')
        test_data.append(json.load(f))
        f.close()
        vocab.add_sentence(test_data[-1]['reviewText'].split())
        vocab.add_sentence(test_data[-1]['summary'].split())
        vocab.add_user(test_data[-1]['userID'])
        vocab.add_product(test_data[-1]['productID'])

    embed = vocab.trim()
    args.embed_num = len(embed)
    args.embed_dim = len(embed[0])
    args.user_num = vocab.user_num
    args.product_num = vocab.product_num

    test_dataset = Dataset(test_data)
    test_iter = DataLoader(dataset=test_dataset,
                           batch_size=args.batch_size,
                           shuffle=False)

    print('Loading model...')
    checkpoint = torch.load(args.save_path + args.load_model)
    net = EncoderDecoder(checkpoint['args'], embed)
    net.load_state_dict(checkpoint['model'])
    if args.use_cuda:
        net.cuda()
    criterion = nn.NLLLoss(ignore_index=vocab.PAD_IDX, reduction='sum')

    print('Begin testing...')
    loss, r1, r2, rl = evaluate(net, criterion, vocab, test_iter, False)
    print('Loss: %f Rouge-1: %f Rouge-2: %f Rouge-l: %f' % (loss, r1, r2, rl))

def test_all():
    embed = None
    if args.embed_path is not None and os.path.exists(args.embed_path):
        print('Loading pretrained word embedding...')
        embed = {}
        with open(args.embed_path, 'r') as f:
            f.readline()
            for line in f.readlines():
                line = line.strip().split()
                vec = [float(_) for _ in line[1:]]
                embed[line[0]] = vec
    vocab = Vocab(args, embed)

    print('Loading datasets...')
    train_data, val_data, test_data = [], [], []
    fns = os.listdir(args.train_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.train_dir + fn, 'r')
        train_data.append(json.load(f))
        f.close()
        vocab.add_sentence(train_data[-1]['reviewText'].split())
        vocab.add_sentence(train_data[-1]['summary'].split())
        vocab.add_user(train_data[-1]['userID'])
        vocab.add_product(train_data[-1]['productID'])
    fns = os.listdir(args.valid_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.valid_dir + fn, 'r')
        val_data.append(json.load(f))
        f.close()
        vocab.add_sentence(val_data[-1]['reviewText'].split())
        vocab.add_sentence(val_data[-1]['summary'].split())
        vocab.add_user(val_data[-1]['userID'])
        vocab.add_product(val_data[-1]['productID'])
    fns = os.listdir(args.test_dir)
    fns.sort(key=lambda p: int(p.split('.')[0]))
    for fn in tqdm(fns):
        f = open(args.test_dir + fn, 'r')
        test_data.append(json.load(f))
        f.close()
        vocab.add_sentence(test_data[-1]['reviewText'].split())
        vocab.add_sentence(test_data[-1]['summary'].split())
        vocab.add_user(test_data[-1]['userID'])
        vocab.add_product(test_data[-1]['productID'])

    print('Deleting rare words...')
    embed = vocab.trim()
    args.embed_num = len(embed)
    args.embed_dim = len(embed[0])
    args.user_num = vocab.user_num
    args.product_num = vocab.product_num

    test_dataset = Dataset(test_data)
    test_iter = DataLoader(dataset=test_dataset,
                           batch_size=args.batch_size,
                           shuffle=False,
                           collate_fn=my_collate)

    # Evaluate every saved validation checkpoint with index in [start, end) on the
    # test set and log the loss and ROUGE scores to the 'result' file.
    start, end = 10, 33
    fns = os.listdir(args.save_path)
    f = open('result', 'w')
    for idx in range(start, end):
        fn = None
        for name in fns:
            if 'valid_%d_' % idx in name:
                fn = name
                break
        if fn is None:
            continue
        checkpoint = torch.load(args.save_path + fn)
        if args.model == 'linear':
            net = MemAttrLinear(checkpoint['args'], embed)
        else:
            net = MemAttrGate(checkpoint['args'], embed)
        net.load_state_dict(checkpoint['model'])
        if args.use_cuda:
            net.cuda()
        criterion = nn.NLLLoss(ignore_index=vocab.PAD_IDX, reduction='sum')
        loss, r1, r2, rl = evaluate(net, criterion, vocab, test_iter, train_data, True)
        f.write('Idx: %d Loss: %f Rouge-1: %f Rouge-2: %f Rouge-l: %f\n'
                % (idx, loss, r1, r2, rl))
    f.close()

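# my_collate() is assumed to be defined elsewhere in this project; test_all() passes it
# as collate_fn so the DataLoader does not apply the default tensor collation to the
# raw JSON example dicts. An illustrative pass-through sketch, not necessarily the
# actual implementation:
#
# def my_collate(batch):
#     # Keep the batch as a plain list of example dicts, untouched.
#     return batch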