def train(**kwargs):
    """Train the caption model.

    Keyword arguments override the corresponding attributes on the global
    ``opt`` config object (e.g. ``train(lr=1e-4, max_epoch=10)``).

    Side effects: mutates ``opt``, writes checkpoints to ``checkpoints/``,
    prints the running loss once per epoch.
    """
    # Apply CLI/keyword overrides before anything reads opt.
    for k, v in kwargs.items():
        setattr(opt, k, v)
    device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

    dataloader = get_dataloader(opt)
    model = CaptionModel(opt, dataloader.dataset.word2ix, dataloader.dataset.id2ix)
    if opt.model_path:
        # Resume from a checkpoint; load on CPU first, then move to device.
        model.load_state_dict(t.load(opt.model_path, map_location='cpu'))
    # NOTE(review): cudnn disabled here, presumably to work around an RNN
    # issue — confirm whether this is still needed.
    t.backends.cudnn.enabled = False
    model = model.to(device)

    optimizer = Adam(model.parameters(), opt.lr)
    criterion = t.nn.CrossEntropyLoss()

    for epoch in range(opt.max_epoch):
        # total= lets tqdm render an actual progress bar instead of a bare counter.
        for ii, (imgs, (captions, lengths), indexes) in tqdm.tqdm(
                enumerate(dataloader), total=len(dataloader)):
            # Variable() has been a deprecated no-op since PyTorch 0.4;
            # moving the tensors to the device is all that is needed.
            imgs = imgs.to(device)
            captions = captions.to(device)
            pred, _ = model(imgs, captions, lengths)
            # Pack the padded targets the same way the model packs its
            # output so pred and target_captions align element-wise.
            target_captions = pack_padded_sequence(captions, lengths)[0]
            loss = criterion(pred, target_captions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Current Loss: ", loss.item())
        if (epoch + 1) % opt.save_model == 0:
            t.save(model.state_dict(), "checkpoints/{}.pth".format(epoch))
def generate(**kwargs):
    """Generate and print a caption for the image at ``opt.test_img``.

    Keyword arguments override attributes on the global ``opt`` config
    object. Loads the vocabulary from ``opt.caption_path``, extracts image
    features with a pretrained ResNet-50, and decodes with a trained
    CaptionModel loaded from ``opt.model_path``.
    """
    for k, v in kwargs.items():
        setattr(opt, k, v)
    # BUG FIX: the original tested `t.cuda.is_available` without calling it;
    # a bound function object is always truthy, so 'cuda' was selected even
    # on CPU-only machines and the script crashed there.
    device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

    # Vocabulary mappings saved alongside the preprocessed captions.
    data = t.load(opt.caption_path, map_location=lambda s, l: s)
    word2ix, ix2word = data['word2ix'], data['ix2word']

    # Standard ImageNet preprocessing (matches the pretrained ResNet-50).
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(224),
        tv.transforms.CenterCrop(224),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img = Image.open(opt.test_img)
    img = transforms(img).unsqueeze(0)

    # Feature extractor: replace the classification head with an identity
    # so the forward pass returns the pooled convolutional features.
    resnet50 = tv.models.resnet50(True).eval()
    del resnet50.fc
    resnet50.fc = lambda x: x
    resnet50 = resnet50.to(device)
    img = img.to(device)
    img_feats = resnet50(img).detach()

    # Caption Model
    model = CaptionModel(opt, word2ix, ix2word)
    model.load_state_dict(t.load(opt.model_path, map_location='cpu'))
    model.to(device)
    # .data is deprecated; img_feats is already detached, so plain indexing
    # is the correct way to take the single batch element.
    results = model.generate(img_feats[0])
    print('\r\n'.join(results))
# Driver: run training, then optionally evaluate the best checkpoint.
# NOTE(review): criteria are moved to CUDA unconditionally — this will fail
# on CPU-only machines; confirm a GPU is guaranteed at this point.
xe_criterion.cuda()
rl_criterion.cuda()
logger.info('Start training...')
start = datetime.now()
optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
# train() returns bookkeeping info including the best validation score and
# when it was reached (rl_criterion enables the RL fine-tuning path).
infos = train(model, xe_criterion, optimizer, train_loader, val_loader, opt,
              rl_criterion=rl_criterion)
logger.info('Best val %s score: %f. Best iter: %d. Best epoch: %d',
            opt.eval_metric, infos['best_score'], infos['best_iter'],
            infos['best_epoch'])
logger.info('Training time: %s', datetime.now() - start)
if opt.result_file:
    # Reload the best model saved during training before testing.
    logger.info('Start testing...')
    start = datetime.now()
    logger.info('Loading model: %s', opt.model_file)
    checkpoint = torch.load(opt.model_file)
    model.load_state_dict(checkpoint['model'])
    test(model, xe_criterion, test_loader, opt)
    logger.info('Testing time: %s', datetime.now() - start)
transform=mytransform, train=True) flicker8k_val = FlickrDataLoader.Flicker8k(img_dir, cap_path, val_txt, transform=mytransform, train=True) with open('feat6k.npy', 'r') as f: feat_tr = np.load(f) with open('capt6k.pkl', 'r') as f: caption_trn = pickle.load(f) with open('feat.pkl', 'r') as f: feat_val = pickle.load(f) with open('capt1k.pkl', 'r') as f: caption_val = pickle.load(f) model = CaptionModel(bsz=1, feat_dim=(196, 512), n_voc=5834, n_embed=512, n_hidden=1024).cuda() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) train(epoches=1) with open('model_t.pth', 'r') as f: model.load_state_dict(torch.load(f))