def main(args):
    """Train or evaluate a Transformer NMT model (Multi30k en->de).

    When ``args.test`` is False: trains with label smoothing, tracks the best
    validation loss, and checkpoints to ``args.model_dir``.  When True: loads a
    checkpoint and greedy-decodes the test set, writing results/pred.txt and
    scoring BLEU via a shell script.
    """
    src, tgt = load_data(args.path)

    # Shared special-token configuration for both vocabularies.
    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # Special-token indices (assumed fixed by the Vocab files — TODO confirm).
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50  # maximum decode length at test time
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Set hyper parameter
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = make_model(src_vocab_size, tgt_vocab_size).to(device)
    optimizer = get_std_opt(model)
    criterion = LabelSmoothing(size=tgt_vocab_size, padding_idx=pad_idx, smoothing=0.1)
    # train_criterion steps the optimizer internally; valid_criterion (opt=None) does not.
    train_criterion = SimpleLossCompute(model.generator, criterion, optimizer)
    valid_criterion = SimpleLossCompute(model.generator, criterion, None)
    print('Using device:', device)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab, batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab, batch_size=args.batch_size)

        best_loss = 987654321  # sentinel; first epoch always becomes the best
        for epoch in range(args.epochs):
            train_total_loss, valid_total_loss = 0.0, 0.0
            start = time.time()
            total_tokens = 0
            tokens = 0
            model.train()
            # Train
            for src_batch, tgt_batch in train_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)
                prediction = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
                loss = train_criterion(prediction, batch.trg_y, batch.ntokens)
                train_total_loss += loss
                total_tokens += batch.ntokens
                tokens += batch.ntokens
            # Valid
            # NOTE(review): validation runs without torch.no_grad(), so the
            # autograd graph is retained for every batch — confirm intended.
            model.eval()
            for src_batch, tgt_batch in valid_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)
                prediction = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
                loss = valid_criterion(prediction, batch.trg_y, batch.ntokens)
                valid_total_loss += loss
                total_tokens += batch.ntokens
                tokens += batch.ntokens
            # NOTE(review): best_loss is first compared as a float but then
            # re-assigned a tensor — works via tensor comparison, but fragile.
            if valid_total_loss.item() < best_loss:
                best_loss = valid_total_loss
                best_model_state = model.state_dict()
                best_optimizer_state = optimizer.optimizer.state_dict()
            elpsed = time.time() - start
            print(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + "|| [" +
                str(epoch) + "/" + str(args.epochs) + "], train_loss = " +
                str(train_total_loss.item()) + ", valid_loss = " +
                str(valid_total_loss.item()) + ", Tokens per Sec = " +
                str(tokens.item() / elpsed))
            tokens = 0
            start = time.time()
            if epoch % 100 == 0:
                # Save model (periodic intermediate checkpoint, incl. epoch 0)
                torch.save(
                    {
                        'epoch': args.epochs,
                        'model_state_dict': best_model_state,
                        'optimizer_state': best_optimizer_state,
                        'loss': best_loss
                    }, args.model_dir + "/intermediate.pt")
                print("Model saved")
        # Save model (best state seen over all epochs)
        torch.save(
            {
                'epoch': args.epochs,
                'model_state_dict': best_model_state,
                'optimizer_state': best_optimizer_state,
                'loss': best_loss
            }, args.model_dir + "/best.pt")
        print("Model saved")
    else:
        # Load the model
        checkpoint = torch.load(args.model_dir + "/" + args.model_name, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state'])
        model.eval()
        print("Model loaded")
        # Test: greedy decoding, one token per step, up to max_length.
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab, batch_size=args.batch_size)
        pred = []
        for src_batch, tgt_batch in test_loader:
            src_batch = torch.tensor(src_batch).to(device)
            tgt_batch = torch.tensor(tgt_batch).to(device)
            batch = Batch(src_batch, tgt_batch, pad_idx)
            # Get pred_batch: start every hypothesis with <sos>.
            memory = model.encode(batch.src, batch.src_mask)
            pred_batch = torch.ones(src_batch.size(0), 1)\
                .fill_(sos_idx).type_as(batch.src.data).to(device)
            for i in range(max_length - 1):
                out = model.decode(
                    memory, batch.src_mask, Variable(pred_batch),
                    Variable(
                        Batch.make_std_mask(pred_batch, pad_idx).type_as(batch.src.data)))
                prob = model.generator(out[:, -1])
                # Forbid <sos>/<pad> from being generated mid-sentence.
                prob.index_fill_(1, torch.tensor([sos_idx, pad_idx]).to(device), -float('inf'))
                _, next_word = torch.max(prob, dim=1)
                pred_batch = torch.cat(
                    [pred_batch, next_word.unsqueeze(-1)], dim=1)
            # Force-terminate every hypothesis with <eos>.
            pred_batch = torch.cat([pred_batch, torch.ones(src_batch.size(0), 1)\
                .fill_(eos_idx).type_as(batch.src.data).to(device)], dim=1)
            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch.tolist(), tgt_vocab)
        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))
        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
def main(args):
    """Train or test a Transformer on Multi30k with file/stream logging.

    Creates ./ckpt and ./results directory trees keyed by ``args.name``, wires
    a logger plus device info onto ``args``, then either trains (tracking loss
    and validation accuracy with AverageMeters) or greedy-decodes the test set
    and scores BLEU.
    """
    # 0. initial setting
    # set environmet
    cudnn.benchmark = True
    if not os.path.isdir('./ckpt'):
        os.mkdir('./ckpt')
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    if not os.path.isdir(os.path.join('./ckpt', args.name)):
        os.mkdir(os.path.join('./ckpt', args.name))
    if not os.path.isdir(os.path.join('./results', args.name)):
        os.mkdir(os.path.join('./results', args.name))
    if not os.path.isdir(os.path.join('./results', args.name, "log")):
        os.mkdir(os.path.join('./results', args.name, "log"))

    # set logger: log lines go both to a timestamped file and to stderr.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda — device and flags are attached to args for downstream use.
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data
    args.logger.info("loading data...")
    src, tgt = load_data(args.path)
    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup
    args.logger.info("setting up...")
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)
    # transformer config
    d_e = 512  # embedding size
    d_q = 64  # query size (= key, value size)
    d_h = 2048  # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of encoder/decoder layers in encoder/decoder
    # Stash the whole configuration on args so Transformer(args) can read it.
    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers
    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test
    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)", save_all=True,
                                     save_dir=os.path.join('results', args.name))
        train_loss_meter = AverageMeter(name="Loss", save_all=True,
                                        save_dir=os.path.join('results', args.name))
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)
        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch, tgt_batch = torch.LongTensor(src_batch).to(
                    args.device), torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output (teacher forcing)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]
                pred = model(src_batch.to(args.device), tgt_batch_i.to(args.device))
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()
                train_loss_tmp_meter.update(loss / batch, weight=batch)
            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            # validation: token-level accuracy, no parameter updates
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()
            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(
                    src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]
                with torch.no_grad():
                    pred = model(src_batch.to(args.device),
                                 tgt_batch_i.to(args.device))
                corrects, total = val_check(
                    pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                acc_val_tmp_meter.update(100 * corrects / total, total)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] validation accuracy: {:.1f} %, took {} seconds".format(
                    epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)
            # Periodic checkpoint + meter persistence (grouping assumed —
            # TODO confirm the meter saves belong inside this condition).
            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test: greedy decoding with early stop once every sentence ended.
        args.logger.info("starting test")
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()
        for src_batch, tgt_batch in test_loader:
            #src_batch: (batch x source_length)
            src_batch = torch.Tensor(src_batch).long().to(args.device)
            batch = src_batch.shape[0]
            pred_batch = torch.zeros(batch, 1).long().to(args.device)
            pred_mask = torch.zeros(batch, 1).bool().to(
                args.device)  # mask whether each sentece ended up
            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(
                        src_batch, pred_batch)  # (batch x length x tgt_vocab_size)
                    pred[:, :, pad_idx] = -1  # ignore <pad>
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(
                        -1)  # next word prediction: (batch x 1)
                    pred = pred.masked_fill(
                        pred_mask, 2).long()  # fill out <pad> for ended sentences
                    # sentence is finished once it emitted <eos> (1) or <pad> (2)
                    pred_mask = torch.gt(pred.eq(1) + pred.eq(2), 0)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if torch.prod(pred_mask) == 1:
                        break
            # Append <eos> (or <pad> for already-ended rows) to close all sentences.
            pred_batch = torch.cat([
                pred_batch,
                torch.ones(batch, 1).long().to(args.device) + pred_mask.long()
            ], dim=1)  # close all sentences
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)
        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))
        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
def main(args):
    """Train or test an Encoder/Decoder Transformer with Noam-style LR warmup.

    Training uses the inverse-square-root schedule from "Attention Is All You
    Need" (warmup 4000 steps), checkpoints the best validation loss to a Google
    Drive path, and can resume from that checkpoint via ``args.model_load``.
    Testing greedy-decodes the test set and scores BLEU.
    """
    src, tgt = load_data(args.path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)
    N = 6      # number of encoder/decoder layers
    dim = 512  # model dimension
    # MODEL Construction
    encoder = Encoder(N, dim, pad_idx, src_vocab_size, device).to(device)
    decoder = Decoder(N, dim, pad_idx, tgt_vocab_size, device).to(device)
    if args.model_load:
        # Hard-coded Colab/Drive checkpoint path — TODO make configurable.
        ckpt = torch.load("drive/My Drive/checkpoint/best.ckpt")
        encoder.load_state_dict(ckpt["encoder"])
        decoder.load_state_dict(ckpt["decoder"])
    params = list(encoder.parameters()) + list(decoder.parameters())
    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)
        warmup = 4000
        steps = 1
        # Noam schedule: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        lr = 1. * (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
        optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-09)
        train_losses = []
        val_losses = []
        latest = 1e08  # to store latest checkpoint (best validation loss so far)
        start_epoch = 0
        if (args.model_load):
            start_epoch = ckpt["epoch"]
            optimizer.load_state_dict(ckpt["optim"])
            # Approximate resumed step count — assumes ~30 batches/epoch; TODO confirm.
            steps = start_epoch * 30
        for epoch in range(start_epoch, args.epochs):
            for src_batch, tgt_batch in train_loader:
                encoder.train()
                decoder.train()
                optimizer.zero_grad()
                tgt_batch = torch.LongTensor(tgt_batch)
                src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                # Teacher forcing: gt is shifted-right target, input drops last token.
                gt = Variable(tgt_batch[:, 1:]).to(device)
                tgt_batch = Variable(tgt_batch[:, :-1]).to(device)
                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)
                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)
                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                loss.backward()
                train_losses.append(loss.item())
                optimizer.step()
                steps += 1
                # Recompute and apply the scheduled learning rate every step.
                lr = (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
                update_lr(optimizer, lr)
                if (steps % 10 == 0):
                    print("loss : %f" % loss.item())
            # Validation pass.
            # NOTE(review): runs without torch.no_grad(); gradients are built
            # but never used — confirm intended.
            for src_batch, tgt_batch in valid_loader:
                encoder.eval()
                decoder.eval()
                src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                tgt_batch = torch.LongTensor(tgt_batch)
                gt = Variable(tgt_batch[:, 1:]).to(device)
                tgt_batch = Variable(tgt_batch[:, :-1]).to(device)
                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)
                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)
                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                val_losses.append(loss.item())
            print("[EPOCH %d] Loss %f" % (epoch, loss.item()))
            # Checkpoint when the last validation batch improved on the best so far.
            if (val_losses[-1] <= latest):
                checkpoint = {'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), \
                              'optim': optimizer.state_dict(), 'epoch': epoch}
                torch.save(checkpoint, "drive/My Drive/checkpoint/best.ckpt")
                latest = val_losses[-1]
            if (epoch % 20 == 0):
                plt.figure()
                plt.plot(val_losses)
                plt.xlabel("epoch")
                plt.ylabel("model loss")
                plt.show()
    else:
        # test
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        # LOAD CHECKPOINT
        pred = []
        for src_batch, tgt_batch in test_loader:
            encoder.eval()
            decoder.eval()
            b_s = min(args.batch_size, len(src_batch))  # last batch may be short
            tgt_batch = torch.zeros(b_s, 1).to(device).long()  # start from <sos> (0)
            src_batch = Variable(torch.LongTensor(src_batch)).to(device)
            enc_output, seq_mask = encoder(src_batch)
            pred_batch = decoder(tgt_batch, enc_output, seq_mask)
            _, pred_batch = torch.max(pred_batch, 2)
            # Re-decode with the growing prefix until every sentence finished
            # or max_length reached (semantics of is_finished assumed — TODO confirm).
            while (not is_finished(pred_batch, max_length, eos_idx)):
                # do something
                next_input = torch.cat((tgt_batch, pred_batch.long()), 1)
                pred_batch = decoder(next_input, enc_output, seq_mask)
                _, pred_batch = torch.max(pred_batch, 2)
            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred_batch = pred_batch.tolist()
            # Force-terminate each hypothesis with <eos>.
            for line in pred_batch:
                line[-1] = 1
            pred += seq2sen(pred_batch, tgt_vocab)
            # print(pred)
        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))
        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
def main(args):
    """Train or test a Listen-Attend-Spell (LAS) speech recognizer.

    Training: teacher-forced cross-entropy over character classes with Adam,
    checkpointing after every epoch.  Testing: loads listener/speller
    checkpoints, greedy-decodes each batch character by character, and reports
    word error rate (WER) against the references.
    """
    eos_idx = 1
    pad_idx = -1   # label padding index, ignored by the loss
    pad_val = 0.0  # padding value for the acoustic features
    feature_dim = 40
    listener_hidden_dim = 256
    num_of_pyramidal_layers = 3
    speller_hidden_dim = 512
    attention_hidden_dim = 128
    num_of_classes = 30
    learning_rate = 0.2      # only used by the commented-out decay logic below
    geometric_decay = 0.98
    device = torch.device("cuda" if(torch.cuda.is_available()) else "cpu")
    listener = Listener(feature_dim, listener_hidden_dim, num_of_pyramidal_layers).to(device)
    speller = Speller(speller_hidden_dim, listener_hidden_dim, attention_hidden_dim,
                      num_of_classes, device).to(device)
    #print(device, listener, speller)
    if not args.test:
        # train
        src, trg = load_train_data(args.path)
        train_loader = DataLoader(src, trg, args.batch_size, pad_idx, pad_val)
        criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)
        #optimizer = torch.optim.ASGD([{'params':listener.parameters()}, {'params':speller.parameters()}], lr=learning_rate)
        optimizer = torch.optim.Adam([{'params':listener.parameters()}, {'params':speller.parameters()}], lr = 0.0001)
        print('Start training ...')
        for epoch in range(args.epochs):
            start_epoch = time.time()
            i = 0
            for src_batch, tgt_batch in train_loader:
                batch_start = time.time()
                src_batch = torch.tensor(src_batch).to(device)
                trg_batch = torch.tensor(tgt_batch).to(device)
                # Teacher forcing: input drops last token, target drops first.
                trg_input = trg_batch[:,:-1]
                trg_output = trg_batch[:,1:].contiguous().view(-1)
                h = listener(src_batch)
                preds = speller(trg_input, h)
                # lr decay for every 1/20 epoch
                #if (i+1) % ((train_loader.size//args.batch_size)//20) is 0 :
                #    learning_rate = geometric_decay * learning_rate
                #    print('learing rate decayed : %.4f'%(learning_rate))
                #    for group in optimizer.param_groups:
                #        group['lr'] = learning_rate
                optimizer.zero_grad()
                loss = criterion(preds.view(-1, preds.size(-1)), trg_output)
                loss.backward()
                optimizer.step()
                i = i+1
                # flush the GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                batch_time = time.time() - batch_start
                # NOTE(review): "100" here is a hard-coded total-epoch count in
                # the progress string, not args.epochs — confirm intended.
                print('[%d/%d][%d/%d] train loss : %.4f | time : %.2fs'%(epoch+1, 100, i,
                      train_loader.size//args.batch_size + 1, loss.item(), batch_time))
            epoch_time = time.time() - start_epoch
            print('Time taken for %d epoch : %.2fs'%(epoch+1, epoch_time))
            save_checkpoint(listener, speller, 'checkpoints/epoch_%d_'%(epoch+1))
        print('End of the training')
        save_checkpoint(listener, speller, 'checkpoints/final_')
    else:
        # Load both halves of the model if checkpoints exist; otherwise the
        # randomly-initialized networks are evaluated as-is.
        if os.path.exists(args.checkpoint + 'listener') and os.path.exists(args.checkpoint + 'speller'):
            listener_checkpoint = torch.load(args.checkpoint + 'listener')
            listener.load_state_dict(listener_checkpoint['state_dict'])
            print("trained model " + args.checkpoint + "listener is loaded")
            speller_checkpoint = torch.load(args.checkpoint + 'speller')
            speller.load_state_dict(speller_checkpoint['state_dict'])
            print("trained model " + args.checkpoint + "speller is loaded")
        # test
        src, trg = load_test_data(args.path)
        test_loader = DataLoader(src, trg, args.batch_size, pad_idx, pad_val)
        mapping = get_mapping(args.path)
        j = 0
        pred = []
        ref = []
        for src_batch, trg_batch in test_loader:
            # predict pred_batch from src_batch with your model.
            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            start_batch = time.time()
            src_batch = torch.tensor(src_batch).to(device)
            trg_batch = torch.tensor(trg_batch).to(device)
            max_length = trg_batch.size(1)
            pred_batch = torch.zeros(args.batch_size, 1, dtype = int).to(device) # [batch, 1] = [[0],[0],...,[0]]
            # eos_mask[i] = 1 means i-th sentence has eos
            eos_mask = torch.zeros(args.batch_size, dtype = int)
            h = listener(src_batch)
            # NOTE(review): decoding runs without torch.no_grad() — confirm intended.
            for k in range(max_length):
                start = time.time()
                output = speller(pred_batch, h) # [batch, k+1, num_class]
                # greedy search
                output = torch.argmax(F.softmax(output, dim = -1), dim = -1) # [batch_size, k+1]
                predictions = output[:,-1].unsqueeze(1)
                pred_batch = torch.cat([pred_batch, predictions], dim = -1)
                for i in range(args.batch_size):
                    if predictions[i] == eos_idx:
                        eos_mask[i] = 1
                # every sentence has eos
                if eos_mask.sum() == args.batch_size :
                    break
                t = time.time() - start
                print("[%d/%d][%d/%d] prediction done | time : %.2fs"%(j,
                      test_loader.size // args.batch_size + 1, k+1, max_length, t))
            j += 1
            # flush the GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("[%d/%d] prediction done | time : %.2fs"%(j,
                  test_loader.size // args.batch_size + 1, time.time() - start_batch))
            pred += seq2sen(pred_batch.cpu().numpy().tolist(), mapping)
            ref += seq2sen(trg_batch.cpu().numpy().tolist(), mapping)
            # Periodically dump intermediate WER and transcript files.
            if j % 10 == 0:
                WER = word_error_rate(ref, pred)
                print("Test [%d/%d] : WER %.2f%%"%(j, test_loader.size // args.batch_size + 1, WER))
                with open('results/pred_%d.txt'%(j), 'w') as f:
                    for line in pred:
                        f.write('{}\n'.format(line))
                with open('results/ref_%d.txt'%(j), 'w') as f:
                    for line in ref:
                        f.write('{}\n'.format(line))
        WER = word_error_rate(ref, pred)
        print("Test End : WER %.2f%%"%(WER))
def main(args):
    """Train or test a Transformer via the project's ``run_epoch`` helper.

    Training saves the whole model object each epoch ('last_model') and keeps
    the best-validation-loss snapshot ('best_model').  Testing reloads the best
    model, decodes the test set, and scores BLEU.

    NOTE(review): ``device`` is not defined in this function — it is assumed to
    be a module-level global; confirm it exists at import time.
    """
    src, tgt = load_data(args.path)
    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))
    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)
    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)
        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
        best_valid_loss = 10.0  # initial threshold for checkpointing
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))
            net.train()
            # run_epoch with an optimizer trains; with None it only evaluates.
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))
            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))
            # NOTE(review): saves the full module object, not a state_dict —
            # ties the checkpoint to the class definition; confirm acceptable.
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred = []
        iter_cnt = 0
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)
            res = net.decode(source, src_mask)
            pred_batch = res.tolist()
            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)
            iter_cnt += 1
            #print(pred_batch)
        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))
        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
def main(args):
    """Train or test a show-attend-tell style image-captioning model.

    The CNN encoder is frozen (eval, never optimized); only the attention
    decoder is trained with cross-entropy plus a doubly-stochastic attention
    regularizer.  Validation/test decode greedily and report corpus BLEU-1..4;
    training early-stops when BLEU-1 stalls for ``early_stop_criterion`` rounds.
    """
    # constant definition
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    a_dim = 512
    h_dim = 512
    attn_dim = 512
    embed_dim = 512
    regularize_constant = 1.  # lambda * L => lambda = 1/L
    vocabulary = torch.load(args.voca_path)
    vocab_size = len(vocabulary)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    encoder = Encoder().to(device)
    decoder = Decoder(a_dim, h_dim, attn_dim, vocab_size, embed_dim).to(device)
    # We do not train the encoder
    encoder.eval()
    if not args.test:
        # train
        validation_term = 1   # validate every epoch
        best_bleu = 0.
        num_of_epochs_since_improvement = 0
        early_stop_criterion = 20
        train_loader = get_train_data_loader(args.path, args.token_path, args.voca_path,
                                             args.batch_size, pad_idx)
        valid_loader = get_test_data_loader(args.path, args.token_path, args.voca_path,
                                            args.batch_size, pad_idx, dataset_type='valid')
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        optimizer = torch.optim.Adam(decoder.parameters(), lr=0.0001)
        print('Start training ...')
        for epoch in range(args.epochs):
            # early stopping
            if num_of_epochs_since_improvement > early_stop_criterion:
                print("There's no improvement on BLEU score while %d epochs" % (num_of_epochs_since_improvement))
                print("Stop Training")
                break
            start_epoch = time.time()
            i = 0
            ############################################################################################################################################
            # training
            decoder.train()
            for src_batch, trg_batch in train_loader:
                batch_start = time.time()
                src_batch = src_batch.to(device)
                trg_batch = torch.tensor(trg_batch).to(device)
                # Teacher forcing: input drops last token, target drops first.
                trg_input = trg_batch[:, :-1]
                trg_output = trg_batch[:, 1:].contiguous().view(-1)
                a = encoder(src_batch)
                preds, alphas = decoder(
                    a, trg_input)  # [batch, C, vocab_size], [batch, C, L]
                optimizer.zero_grad()
                loss = criterion(preds.view(-1, preds.size(-1)), trg_output)  # NLL loss
                # Doubly-stochastic attention penalty: encourage each spatial
                # location's attention to sum to 1 over time.
                regularize_term = regularize_constant * (
                    (1. - torch.sum(alphas, dim=1))**2).mean()
                total_loss = loss + regularize_term
                total_loss.backward()
                optimizer.step()
                i = i + 1
                # flush the GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                batch_time = time.time() - batch_start
                print(
                    '[%d/%d][%d/%d] train loss : %.4f (%.4f / %.4f) | time : %.2fs'
                    % (epoch + 1, args.epochs, i,
                       train_loader.size // args.batch_size + 1,
                       total_loss.item(), loss.item(), regularize_term.item(),
                       batch_time))
            epoch_time = time.time() - start_epoch
            print('Time taken for %d epoch : %.2fs' % (epoch + 1, epoch_time))
            ############################################################################################################################################
            # validation
            # NOTE(review): with validation_term == 1 this condition is always
            # true; it keys on batch count i, not epoch — confirm intended.
            if i % validation_term == 0:
                decoder.eval()
                j = 0
                pred, ref = [], []
                for src_batch, trg_batch in valid_loader:
                    start = time.time()
                    batch_size = src_batch.size(0)
                    src_batch = src_batch.to(device)  # [batch, 3, 244, 244]
                    trg_batch = torch.tensor(trg_batch).to(
                        device)  # [batch * 5, C]
                    # Regroup the 5 reference captions per image: -> [batch, 5, C]
                    trg_batch = torch.split(trg_batch, 5)
                    batches = []
                    for k in range(batch_size):
                        batches.append(trg_batch[k].unsqueeze(0))
                    trg_batch = torch.cat(batches, dim=0)  # [batch, 5, C]
                    max_length = trg_batch.size(-1)
                    pred_batch = torch.zeros(batch_size, 1, dtype=int).to(
                        device)  # [batch, 1] = [[0],[0],...,[0]]
                    # eos_mask[i] = 1 means i-th sentence has eos
                    eos_mask = torch.zeros(batch_size, dtype=int)
                    a = encoder(src_batch)
                    for _ in range(max_length):
                        output, _ = decoder(
                            a, pred_batch)  # [batch, _+1, vocab_size]
                        # greedy search
                        output = torch.argmax(F.softmax(output, dim=-1),
                                              dim=-1)  # [batch_size, _+1]
                        predictions = output[:, -1].unsqueeze(1)
                        pred_batch = torch.cat([pred_batch, predictions], dim=-1)
                        for l in range(batch_size):
                            if predictions[l] == eos_idx:
                                eos_mask[l] = 1
                        # every sentence has eos
                        if eos_mask.sum() == batch_size:
                            break
                    # flush the GPU cache
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    pred += seq2sen(pred_batch.cpu().numpy().tolist(), vocabulary)
                    for m in range(batch_size):
                        ref += [
                            seq2sen(trg_batch[m].cpu().numpy().tolist(), vocabulary)
                        ]
                    t = time.time() - start
                    j += 1
                    print("[%d/%d] prediction done | time : %.2fs" %
                          (j, valid_loader.size // args.batch_size + 1, t))
                bleu_1 = corpus_bleu(ref, pred, weights=(1. / 1., )) * 100
                bleu_2 = corpus_bleu(ref, pred, weights=(
                    1. / 2.,
                    1. / 2.,
                )) * 100
                bleu_3 = corpus_bleu(
                    ref, pred, weights=(
                        1. / 3.,
                        1. / 3.,
                        1. / 3.,
                    )) * 100
                bleu_4 = corpus_bleu(
                    ref, pred, weights=(
                        1. / 4.,
                        1. / 4.,
                        1. / 4.,
                        1. / 4.,
                    )) * 100
                print(f'BLEU-1: {bleu_1:.2f}')
                print(f'BLEU-2: {bleu_2:.2f}')
                print(f'BLEU-3: {bleu_3:.2f}')
                print(f'BLEU-4: {bleu_4:.2f}')
                # Track BLEU-1 improvement for early stopping / best checkpoint.
                if bleu_1 > best_bleu:
                    num_of_epochs_since_improvement = 0
                    best_bleu = bleu_1
                    print('Best BLEU-1 has been updated : %.2f' % (best_bleu))
                    save_checkpoint(decoder, 'checkpoints/best')
                else:
                    num_of_epochs_since_improvement += validation_term
                    print(
                        "There's no improvement on BLEU score while %d epochs"
                        % (num_of_epochs_since_improvement))
        ################################################################################################################################################################
        print('End of the training')
    else:
        if os.path.exists(args.checkpoint):
            decoder_checkpoint = torch.load(args.checkpoint)
            decoder.load_state_dict(decoder_checkpoint['state_dict'])
            print("trained decoder " + args.checkpoint + " is loaded")
        decoder.eval()
        # test — same greedy decode as validation, then BLEU + transcript dumps.
        test_loader = get_test_data_loader(args.path, args.token_path, args.voca_path,
                                           args.batch_size, pad_idx)
        j = 0
        pred, ref = [], []
        for src_batch, trg_batch in test_loader:
            # predict pred_batch from src_batch with your model.
            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            start = time.time()
            batch_size = src_batch.size(0)
            src_batch = src_batch.to(device)  # [batch, 3, 244, 244]
            trg_batch = torch.tensor(trg_batch).to(device)  # [batch * 5, C]
            trg_batch = torch.split(trg_batch, 5)
            batches = []
            for k in range(batch_size):
                batches.append(trg_batch[k].unsqueeze(0))
            trg_batch = torch.cat(batches, dim=0)  # [batch, 5, C]
            max_length = trg_batch.size(-1)
            pred_batch = torch.zeros(batch_size, 1, dtype=int).to(
                device)  # [batch, 1] = [[0],[0],...,[0]]
            # eos_mask[i] = 1 means i-th sentence has eos
            eos_mask = torch.zeros(batch_size, dtype=int)
            a = encoder(src_batch)
            for _ in range(max_length):
                output, _ = decoder(a, pred_batch)  # [batch, _+1, vocab_size]
                # greedy search
                output = torch.argmax(F.softmax(output, dim=-1),
                                      dim=-1)  # [batch_size, _+1]
                predictions = output[:, -1].unsqueeze(1)
                pred_batch = torch.cat([pred_batch, predictions], dim=-1)
                for l in range(batch_size):
                    if predictions[l] == eos_idx:
                        eos_mask[l] = 1
                # every sentence has eos
                if eos_mask.sum() == batch_size:
                    break
            # flush the GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            pred += seq2sen(pred_batch.cpu().numpy().tolist(), vocabulary)
            for m in range(batch_size):
                ref += [
                    seq2sen(trg_batch[m].cpu().numpy().tolist(), vocabulary)
                ]
            t = time.time() - start
            j += 1
            print("[%d/%d] prediction done | time : %.2fs" %
                  (j, test_loader.size // args.batch_size + 1, t))
        bleu_1 = corpus_bleu(ref, pred, weights=(1. / 1., )) * 100
        bleu_2 = corpus_bleu(ref, pred, weights=(
            1. / 2.,
            1. / 2.,
        )) * 100
        bleu_3 = corpus_bleu(ref, pred, weights=(
            1. / 3.,
            1. / 3.,
            1. / 3.,
        )) * 100
        bleu_4 = corpus_bleu(
            ref, pred, weights=(
                1. / 4.,
                1. / 4.,
                1. / 4.,
                1. / 4.,
            )) * 100
        print(f'BLEU-1: {bleu_1:.2f}')
        print(f'BLEU-2: {bleu_2:.2f}')
        print(f'BLEU-3: {bleu_3:.2f}')
        print(f'BLEU-4: {bleu_4:.2f}')
        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))
        with open('results/ref.txt', 'w') as f:
            for lines in ref:
                for line in lines:
                    f.write('{}\n'.format(line))
                f.write('_' * 50 + '\n')
def main(args):
    """Assignment skeleton for a Multi30k en->de translation pipeline.

    Loads the parallel data and vocabularies, then leaves the actual model
    construction, training, validation, and prediction as TODOs.  In test
    mode the placeholder simply echoes the target batch as the "prediction"
    so the BLEU pipeline can be exercised end to end.
    """
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>', pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # TODO: use these information.
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    # TODO: use these values to construct embedding layers
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        # TODO: train
        for epoch in range(args.epochs):
            for src_batch, tgt_batch in train_loader:
                pass

            # TODO: validation
            for src_batch, tgt_batch in valid_loader:
                pass
    else:
        # test
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        for src_batch, tgt_batch in test_loader:
            # TODO: predict pred_batch from src_batch with your model.
            pred_batch = tgt_batch

            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')