import pickle
import time
from os.path import join as pjoin

import nltk
import numpy as np
import torch
from torch.autograd import Variable

# The remaining helpers (Vocab, VSE, Record, as_variable, as_cuda, var_with,
# _prepare_batch, get_test_loader, order_sim, AverageMeter, LogCollector) are
# assumed to be provided elsewhere in the surrounding project.


class FeatureExtractor(object):
    """Extracts VSE image and caption embeddings for a single dataset item."""

    def __init__(self, checkpoint, image_encoder, dataset):
        self.load_checkpoint(checkpoint)
        self.image_encoder = image_encoder
        self.dataset = dataset

    def load_checkpoint(self, checkpoint):
        checkpoint = torch.load(checkpoint)
        opt = checkpoint['opt']
        opt.use_external_captions = False
        vocab = Vocab.from_pickle(
            pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name))
        opt.vocab_size = len(vocab)

        from model import VSE
        self.model = VSE(opt)
        self.model.load_state_dict(checkpoint['model'])
        self.projector = vocab

        # Freeze both encoders; the extractor only needs forward passes.
        self.model.img_enc.eval()
        self.model.txt_enc.eval()
        for p in self.model.img_enc.parameters():
            p.requires_grad = False
        for p in self.model.txt_enc.parameters():
            p.requires_grad = False

    def __call__(self, ind):
        raw_img, img, img_embedding, cap, cap_ext = self.dataset[ind]

        # Embedding computed from the precomputed image features.
        img_embedding_precomp = self.model.img_enc(
            as_cuda(as_variable(img_embedding).unsqueeze(0)))

        # Embedding recomputed from the raw image, keeping gradients w.r.t. it.
        img = as_variable(img)
        img.requires_grad = True
        img_embedding_a = img_embedding = self.image_encoder(
            as_cuda(img.unsqueeze(0)))
        img_embedding = self.model.img_enc(img_embedding)

        # Encode the ground-truth caption together with any external captions.
        txt = [cap]
        txt.extend(cap_ext)
        txt_embeddings, txt_var = self.enc_txt(txt)

        return Record(
            raw_img, cap, cap_ext,
            img, img_embedding, img_embedding_precomp,
            txt_var, txt_embeddings[0], txt_embeddings[1:])

    def enc_txt(self, caps):
        # Tokenise and pad the captions, then undo the length-sorting so the
        # returned embeddings line up with the input order.
        sents, lengths, _, inv = _prepare_batch(caps, self.projector)
        inv = var_with(as_variable(inv), sents)
        out, x = self.model.txt_enc.forward(sents, lengths, True)
        return out[inv], x
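
# Illustrative sketch (an assumption, not part of the original code): how the
# extractor is expected to be driven.  `checkpoint_path`, `cnn_encoder` and
# `coco_dataset` are hypothetical arguments; the dataset must yield
# (raw_img, img, img_embedding, cap, cap_ext) tuples as unpacked in
# FeatureExtractor.__call__ above.
def _demo_feature_extraction(checkpoint_path, cnn_encoder, coco_dataset, index=0):
    extractor = FeatureExtractor(checkpoint_path, cnn_encoder, coco_dataset)
    # Returns a Record holding the raw image, the captions, the embedding
    # recomputed from pixels, the precomputed-feature embedding, and the
    # caption embeddings.
    return extractor(index)
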
class ImageRetriever(object):
    def __init__(self, model_path, img_path, precomp_path, split, vocab_path,
                 batch_size):
        checkpoint = torch.load(model_path)
        self.opt = checkpoint['opt']

        # Loader over the raw COCO images; only used to fetch pictures for
        # visualisation, so no vocabulary is passed.
        self.opt.data_path = img_path
        self.opt.data_name = 'coco'
        self.opt.batch_size = batch_size
        self.img_loader = get_test_loader(
            split, self.opt.data_name, None, self.opt.crop_size,
            self.opt.batch_size, self.opt.workers, self.opt)

        # Loader over the precomputed image features used for retrieval.
        self.opt.data_path = precomp_path
        self.opt.data_name = 'coco_precomp'

        # load vocabulary used by the model
        with open(vocab_path, 'rb') as f:
            self.vocab = pickle.load(f)
        self.vocab_size = len(self.vocab)

        self.precomp_loader = get_test_loader(
            split, self.opt.data_name, self.vocab, self.opt.crop_size,
            self.opt.batch_size, self.opt.workers, self.opt)

        self.model = VSE(self.opt)
        # load model state
        self.model.load_state_dict(checkpoint['model'])

        # precompute all image embeddings
        self.model.val_start()

        # numpy array to keep all the embeddings
        self.img_embs = None
        batch_time = AverageMeter()
        val_logger = LogCollector()
        end = time.time()
        log_step = 10
        for i, (images, captions, lengths, ids) in enumerate(self.precomp_loader):
            self.model.logger = val_logger

            # compute the embeddings
            images = Variable(images, volatile=True)
            if torch.cuda.is_available():
                images = images.cuda()

            # Forward
            img_emb = self.model.img_enc(images)

            # initialize the numpy array given the size of the embeddings
            if self.img_embs is None:
                self.img_embs = np.zeros(
                    (len(self.precomp_loader.dataset), img_emb.size(1)))

            # preserve the embeddings by copying from gpu and converting to numpy
            self.img_embs[ids] = img_emb.data.cpu().numpy().copy()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                print('Test: [{0}/{1}]\t'
                      '{e_log}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                          i, len(self.precomp_loader), batch_time=batch_time,
                          e_log=str(self.model.logger)))
            del images, captions

    def get_NN(self, query, measure='dot', k=5):
        # convert the query from a string to vocabulary indices
        tokens = nltk.tokenize.word_tokenize(
            str(query).lower().decode('utf-8'))
        query = []
        query.append(self.vocab('<start>'))
        query.extend([self.vocab(token) for token in tokens])
        query.append(self.vocab('<end>'))

        # embed the query
        query = Variable(torch.LongTensor(query), volatile=True)
        if torch.cuda.is_available():
            query = query.cuda()

        # Forward
        q_emb = self.model.txt_enc(query, [query.size(0)])
        q_embs = q_emb.data.cpu().numpy().copy()

        # run nearest-neighbour search: text -> image
        return self.find_NN(self.img_embs, q_embs, measure=measure, k=k)

    def find_NN(self, images, query, measure, k, npts=None):
        """
        Text -> Images (image search).

        images: (5N, K) matrix of image embeddings (each image repeated 5
                times, once per caption)
        query:  (1, K) matrix with the caption embedding
        """
        if npts is None:
            npts = images.shape[0] // 5
        ims = np.array([images[i] for i in range(0, len(images), 5)])

        # Compute scores
        tic = time.clock()
        if measure == 'order':
            d2 = order_sim(
                torch.Tensor(ims).cuda(), torch.Tensor(query).cuda())
            d2 = d2.cpu().numpy()
            d = d2.T
        else:
            d = np.dot(query, ims.T)
        # TODO Try to optimize this computation, see if sorting is bottleneck
        imgs_by_similarity = np.argsort(np.squeeze(d))[::-1]
        toc = time.clock()
        print('NN search took {} ms over {} images'.format(
            (toc - tic) * 1000.0, ims.shape[0]))

        # Scale back to dataset indices (every image occurs 5 times).
        return imgs_by_similarity[0:k] * 5, query, ims

    def visualise_NN(self, inds, img_identifier, file_name):
        html_file = open(file_name, "w")
        for i, ind in enumerate(inds):
            root, caption, img_id, path, image = \
                self.img_loader.dataset.get_raw_item(ind)
            image.save("./out/{}_img{}.png".format(img_identifier, i + 1))
            img_tag = ('<img src="./out/{}_img{}.png" '
                       'style="max-height: 400px; max-width: 400px;">').format(
                           img_identifier, i + 1)
            html_file.write(img_tag)
        html_file.close()
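
# Illustrative sketch (an assumption, not part of the original code): the paths
# below are placeholders for an MS-COCO setup; only the argument order matches
# ImageRetriever.__init__ above.
def _demo_text_to_image_search():
    retriever = ImageRetriever(
        model_path='runs/coco_vse/model_best.pth.tar',  # hypothetical checkpoint
        img_path='data/coco',                           # raw COCO images
        precomp_path='data/coco_precomp',               # precomputed features
        split='test',
        vocab_path='vocab/coco_precomp_vocab.pkl',      # hypothetical vocab pickle
        batch_size=128)
    # Top-5 images for a free-form caption, dumped to a simple HTML page.
    inds, _, _ = retriever.get_NN('a man riding a wave on a surfboard',
                                  measure='dot', k=5)
    retriever.visualise_NN(inds, img_identifier='surf', file_name='surf.html')
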