from typing import Any, Callable, Optional, Tuple

import torch
from PIL import Image
from torchvision.datasets import VisionDataset


class Flickr8k(VisionDataset):
    def __init__(
        self,
        args,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super(Flickr8k, self).__init__(root, transform=transform,
                                       target_transform=target_transform)
        # ann_file is a torch-saved dict mapping an image path to its list of captions.
        self.annotations = torch.load(ann_file, map_location='cpu')
        self.ids = list(sorted(self.annotations.keys()))

        # Build the caption vocabulary from every annotation sentence.
        # `Vocab` is the project's vocabulary helper (also used by the scripts below).
        self.vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                           pad_token='<pad>', unk_token='<unk>')
        sentences = []
        for key in self.annotations.keys():
            sentences += self.annotations[key]
        self.vocab.build_vocab(sentences)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img_id = self.ids[index]

        # Image
        img = Image.open(img_id).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.ids)

    def get_vocab(self):
        return self.vocab
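# --- Hedged usage sketch (not from the original code): the paths, the transform,
# and the collate function below are assumptions for illustration. Because
# __getitem__ returns (image, list_of_caption_strings), the default DataLoader
# collate cannot batch the variable-length caption lists, so a small custom
# collate_fn keeps them as plain Python lists.
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

def collate_captions(batch):
    images, captions = zip(*batch)                # unzip list of (img, captions) pairs
    return torch.stack(images, 0), list(captions)

dataset = Flickr8k(args=None,
                   root='data/flickr8k',                      # hypothetical path
                   ann_file='data/flickr8k/annotations.pt',   # hypothetical path
                   transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_captions)
vocab = dataset.get_vocab()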
def main(args):
    # 0. initial setting
    # set environment
    cudnn.benchmark = True
    if not os.path.isdir('./ckpt'):
        os.mkdir('./ckpt')
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    if not os.path.isdir(os.path.join('./ckpt', args.name)):
        os.mkdir(os.path.join('./ckpt', args.name))
    if not os.path.isdir(os.path.join('./results', args.name)):
        os.mkdir(os.path.join('./results', args.name))
    if not os.path.isdir(os.path.join('./results', args.name, "log")):
        os.mkdir(os.path.join('./results', args.name, "log"))

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data
    args.logger.info("loading data...")
    src, tgt = load_data(args.path)
    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup
    args.logger.info("setting up...")
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # transformer config
    d_e = 512       # embedding size
    d_q = 64        # query size (= key, value size)
    d_h = 2048      # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of encoder/decoder layers in encoder/decoder

    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers

    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test
    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)", save_all=True,
                                     save_dir=os.path.join('results', args.name))
        train_loss_meter = AverageMeter(name="Loss", save_all=True,
                                        save_dir=os.path.join('results', args.name))
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch, tgt_batch = torch.LongTensor(src_batch).to(args.device), \
                    torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                pred = model(src_batch.to(args.device), tgt_batch_i.to(args.device))
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()
                train_loss_tmp_meter.update(loss / batch, weight=batch)

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info("[{}] train loss: {:.3f} took {:.1f} seconds".format(
                epoch, train_loss_tmp_meter.avg, spent_time))

            # validation
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()
            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                with torch.no_grad():
                    pred = model(src_batch.to(args.device), tgt_batch_i.to(args.device))
                    corrects, total = val_check(pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                    acc_val_tmp_meter.update(100 * corrects / total, total)

            spent_time = time.time() - spent_time
            args.logger.info("[{}] validation accuracy: {:.1f} %, took {} seconds".format(
                epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)

            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test
        args.logger.info("starting test")
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()
        for src_batch, tgt_batch in test_loader:
            # src_batch: (batch x source_length)
            src_batch = torch.Tensor(src_batch).long().to(args.device)
            batch = src_batch.shape[0]
            pred_batch = torch.zeros(batch, 1).long().to(args.device)
            # mask marking sentences that have already ended
            pred_mask = torch.zeros(batch, 1).bool().to(args.device)

            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(src_batch, pred_batch)  # (batch x length x tgt_vocab_size)
                    pred[:, :, pad_idx] = -1  # ignore <pad>
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(-1)  # next word prediction: (batch x 1)
                    pred = pred.masked_fill(pred_mask, 2).long()  # fill out <pad> for ended sentences
                    pred_mask = torch.gt(pred.eq(1) + pred.eq(2), 0)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if torch.prod(pred_mask) == 1:
                        break

            # close all sentences
            pred_batch = torch.cat([
                pred_batch,
                torch.ones(batch, 1).long().to(args.device) + pred_mask.long()
            ], dim=1)
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))

        os.system('bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
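# --- Hedged illustration (standalone, not part of the script above): a minimal
# example of the teacher-forcing shift used in the training loop, where the decoder
# input is tgt[:, :-1], the labels are tgt[:, 1:], and <pad> (index 2, as assumed
# throughout these scripts) is ignored by the loss.
import torch
import torch.nn as nn

pad_idx = 2
vocab_size = 10
tgt = torch.tensor([[0, 5, 6, 7, 1],       # <sos> ... <eos>
                    [0, 4, 9, 1, 2]])      # shorter sentence, padded with <pad>
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
logits = torch.randn(tgt_in.size(0), tgt_in.size(1), vocab_size)  # stand-in for model output
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss = loss_fn(logits.reshape(-1, vocab_size), tgt_out.reshape(-1))
print(loss.item())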
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Set hyperparameters
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = make_model(src_vocab_size, tgt_vocab_size).to(device)
    optimizer = get_std_opt(model)
    criterion = LabelSmoothing(size=tgt_vocab_size, padding_idx=pad_idx, smoothing=0.1)
    train_criterion = SimpleLossCompute(model.generator, criterion, optimizer)
    valid_criterion = SimpleLossCompute(model.generator, criterion, None)
    print('Using device:', device)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        best_loss = float('inf')
        for epoch in range(args.epochs):
            train_total_loss, valid_total_loss = 0.0, 0.0
            start = time.time()
            total_tokens = 0
            tokens = 0

            # Train
            model.train()
            for src_batch, tgt_batch in train_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)
                prediction = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
                loss = train_criterion(prediction, batch.trg_y, batch.ntokens)
                train_total_loss += loss
                total_tokens += batch.ntokens
                tokens += batch.ntokens

            # Valid
            model.eval()
            for src_batch, tgt_batch in valid_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)
                prediction = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
                loss = valid_criterion(prediction, batch.trg_y, batch.ntokens)
                valid_total_loss += loss
                total_tokens += batch.ntokens
                tokens += batch.ntokens

            if valid_total_loss.item() < best_loss:
                best_loss = valid_total_loss
                best_model_state = model.state_dict()
                best_optimizer_state = optimizer.optimizer.state_dict()

            elapsed = time.time() - start
            print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +
                  "|| [" + str(epoch) + "/" + str(args.epochs) +
                  "], train_loss = " + str(train_total_loss.item()) +
                  ", valid_loss = " + str(valid_total_loss.item()) +
                  ", Tokens per Sec = " + str(tokens.item() / elapsed))
            tokens = 0
            start = time.time()

            if epoch % 100 == 0:
                # Save intermediate model
                torch.save({
                    'epoch': args.epochs,
                    'model_state_dict': best_model_state,
                    'optimizer_state': best_optimizer_state,
                    'loss': best_loss
                }, args.model_dir + "/intermediate.pt")
                print("Model saved")

        # Save best model
        torch.save({
            'epoch': args.epochs,
            'model_state_dict': best_model_state,
            'optimizer_state': best_optimizer_state,
            'loss': best_loss
        }, args.model_dir + "/best.pt")
        print("Model saved")
    else:
        # Load the model
        checkpoint = torch.load(args.model_dir + "/" + args.model_name, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state'])
        model.eval()
        print("Model loaded")

        # Test
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred = []
        for src_batch, tgt_batch in test_loader:
            src_batch = torch.tensor(src_batch).to(device)
            tgt_batch = torch.tensor(tgt_batch).to(device)
            batch = Batch(src_batch, tgt_batch, pad_idx)

            # Get pred_batch by greedy decoding: extend it one token at a time.
            memory = model.encode(batch.src, batch.src_mask)
            pred_batch = torch.ones(src_batch.size(0), 1) \
                .fill_(sos_idx).type_as(batch.src.data).to(device)
            for i in range(max_length - 1):
                out = model.decode(
                    memory, batch.src_mask, Variable(pred_batch),
                    Variable(Batch.make_std_mask(pred_batch, pad_idx).type_as(batch.src.data)))
                prob = model.generator(out[:, -1])
                # never predict <sos> or <pad>
                prob.index_fill_(1, torch.tensor([sos_idx, pad_idx]).to(device), -float('inf'))
                _, next_word = torch.max(prob, dim=1)
                pred_batch = torch.cat([pred_batch, next_word.unsqueeze(-1)], dim=1)
            pred_batch = torch.cat([pred_batch,
                                    torch.ones(src_batch.size(0), 1)
                                    .fill_(eos_idx).type_as(batch.src.data).to(device)], dim=1)

            # every sentence in pred_batch should start with the <sos> token (index: 0)
            # and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch.tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system('bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
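# --- Hedged note (assumption): get_std_opt above is the "Annotated Transformer"
# Noam optimizer wrapper, whose learning rate follows
#     lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5),
# the same schedule that the next script computes by hand. A minimal sketch:
def noam_lr(step, d_model=512, warmup=4000, factor=1.0):
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * (warmup ** -1.5))

for s in (1, 1000, 4000, 20000):
    print(s, noam_lr(s))  # linear warmup to the warmup step, then ~1/sqrt(step) decay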
def main(args):
    src, tgt = load_data(args.path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)
    N = 6
    dim = 512

    # Model construction
    encoder = Encoder(N, dim, pad_idx, src_vocab_size, device).to(device)
    decoder = Decoder(N, dim, pad_idx, tgt_vocab_size, device).to(device)
    if args.model_load:
        ckpt = torch.load("drive/My Drive/checkpoint/best.ckpt")
        encoder.load_state_dict(ckpt["encoder"])
        decoder.load_state_dict(ckpt["decoder"])
    params = list(encoder.parameters()) + list(decoder.parameters())

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        # Noam-style warmup schedule
        warmup = 4000
        steps = 1
        lr = 1. * (dim ** -0.5) * min(steps ** -0.5, steps * (warmup ** -1.5))
        optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-09)

        train_losses = []
        val_losses = []
        latest = 1e08  # best validation loss seen so far (for checkpointing)
        start_epoch = 0
        if args.model_load:
            start_epoch = ckpt["epoch"]
            optimizer.load_state_dict(ckpt["optim"])
            steps = start_epoch * 30

        for epoch in range(start_epoch, args.epochs):
            # Train
            for src_batch, tgt_batch in train_loader:
                encoder.train()
                decoder.train()
                optimizer.zero_grad()

                tgt_batch = torch.LongTensor(tgt_batch)
                src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                gt = Variable(tgt_batch[:, 1:]).to(device)
                tgt_batch = Variable(tgt_batch[:, :-1]).to(device)

                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)

                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)
                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                loss.backward()
                train_losses.append(loss.item())
                optimizer.step()

                steps += 1
                lr = (dim ** -0.5) * min(steps ** -0.5, steps * (warmup ** -1.5))
                update_lr(optimizer, lr)
                if steps % 10 == 0:
                    print("loss : %f" % loss.item())

            # Validation
            for src_batch, tgt_batch in valid_loader:
                encoder.eval()
                decoder.eval()
                src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                tgt_batch = torch.LongTensor(tgt_batch)
                gt = Variable(tgt_batch[:, 1:]).to(device)
                tgt_batch = Variable(tgt_batch[:, :-1]).to(device)

                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)

                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)
                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                val_losses.append(loss.item())

            print("[EPOCH %d] Loss %f" % (epoch, loss.item()))
            if val_losses[-1] <= latest:
                checkpoint = {'encoder': encoder.state_dict(),
                              'decoder': decoder.state_dict(),
                              'optim': optimizer.state_dict(),
                              'epoch': epoch}
                torch.save(checkpoint, "drive/My Drive/checkpoint/best.ckpt")
                latest = val_losses[-1]

            if epoch % 20 == 0:
                plt.figure()
                plt.plot(val_losses)
                plt.xlabel("epoch")
                plt.ylabel("model loss")
                plt.show()
    else:
        # test
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)

        # LOAD CHECKPOINT
        pred = []
        for src_batch, tgt_batch in test_loader:
            encoder.eval()
            decoder.eval()
            b_s = min(args.batch_size, len(src_batch))
            tgt_batch = torch.zeros(b_s, 1).to(device).long()
            src_batch = Variable(torch.LongTensor(src_batch)).to(device)

            enc_output, seq_mask = encoder(src_batch)
            pred_batch = decoder(tgt_batch, enc_output, seq_mask)
            _, pred_batch = torch.max(pred_batch, 2)
            while not is_finished(pred_batch, max_length, eos_idx):
                # feed the current prediction back in as the next decoder input
                next_input = torch.cat((tgt_batch, pred_batch.long()), 1)
                pred_batch = decoder(next_input, enc_output, seq_mask)
                _, pred_batch = torch.max(pred_batch, 2)

            # every sentence in pred_batch should start with the <sos> token (index: 0)
            # and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred_batch = pred_batch.tolist()
            for line in pred_batch:
                line[-1] = 1
            pred += seq2sen(pred_batch, tgt_vocab)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system('bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
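# --- Hedged sketch (assumption): update_lr above is project code that is not shown
# here. A typical implementation simply overwrites the learning rate of every
# parameter group in the optimizer:
def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr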
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
        best_valid_loss = 10.0

        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))

            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))

            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))

            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred = []
        iter_cnt = 0
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)

            res = net.decode(source, src_mask)
            pred_batch = res.tolist()
            # every sentence in pred_batch should start with the <sos> token (index: 0)
            # and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)
            iter_cnt += 1

        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system('bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok')
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # TODO: use this information.
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    # TODO: use these values to construct embedding layers
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(args.epochs):
            # TODO: train
            for src_batch, tgt_batch in train_loader:
                pass

            # TODO: validation
            for src_batch, tgt_batch in valid_loader:
                pass
    else:
        # test
        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)
        pred = []
        for src_batch, tgt_batch in test_loader:
            # TODO: predict pred_batch from src_batch with your model.
            pred_batch = tgt_batch

            # every sentence in pred_batch should start with the <sos> token (index: 0)
            # and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system('bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')