import torch
from os.path import join as pjoin

# Project-local helpers (Vocab, as_cuda, as_variable, var_with, Record,
# _prepare_batch) are assumed to be importable from the surrounding codebase.


class FeatureExtractor(object):
    def __init__(self, checkpoint, image_encoder, dataset):
        self.load_checkpoint(checkpoint)
        self.image_encoder = image_encoder
        self.dataset = dataset

    def load_checkpoint(self, checkpoint):
        # restore the training options and rebuild the vocabulary the model
        # was trained with, so caption tokens map to the same indices
        checkpoint = torch.load(checkpoint)
        opt = checkpoint['opt']
        opt.use_external_captions = False
        vocab = Vocab.from_pickle(pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name))
        opt.vocab_size = len(vocab)

        from model import VSE
        self.model = VSE(opt)
        self.model.load_state_dict(checkpoint['model'])
        self.projector = vocab

        # put both encoders in eval mode and freeze their weights; gradients
        # are only needed w.r.t. the input image in __call__
        self.model.img_enc.eval()
        self.model.txt_enc.eval()
        for p in self.model.img_enc.parameters():
            p.requires_grad = False
        for p in self.model.txt_enc.parameters():
            p.requires_grad = False

    def __call__(self, ind):
        raw_img, img, img_embedding, cap, cap_ext = self.dataset[ind]
        # embedding computed from the precomputed CNN features, for reference
        img_embedding_precomp = self.model.img_enc(as_cuda(as_variable(img_embedding).unsqueeze(0)))

        # re-encode the raw image with gradients enabled w.r.t. the pixels
        img = as_variable(img)
        img.requires_grad = True
        img_embedding = self.image_encoder(as_cuda(img.unsqueeze(0)))
        img_embedding = self.model.img_enc(img_embedding)

        # encode the reference caption together with the external captions
        txt = [cap]
        txt.extend(cap_ext)
        txt_embeddings, txt_var = self.enc_txt(txt)

        return Record(
                raw_img, cap, cap_ext,
                img, img_embedding, img_embedding_precomp,
                txt_var, txt_embeddings[0], txt_embeddings[1:]
        )

    def enc_txt(self, caps):
        # _prepare_batch sorts captions by length for sequence packing;
        # `inv` restores the original caption order in the output embeddings
        sents, lengths, _, inv = _prepare_batch(caps, self.projector)
        inv = var_with(as_variable(inv), sents)
        out, x = self.model.txt_enc.forward(sents, lengths, True)
        return out[inv], x
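
A minimal usage sketch for the extractor above. The checkpoint path, image encoder, and dataset below are hypothetical placeholders, not part of the original snippet; the checkpoint must contain the 'opt' and 'model' entries that load_checkpoint expects:

# Hypothetical usage; substitute your own checkpoint, CNN wrapper, and dataset.
image_encoder = build_image_encoder()        # hypothetical CNN wrapper
dataset = CocoCaptionDataset('data/coco')    # hypothetical dataset object
extractor = FeatureExtractor('runs/model_best.pth.tar', image_encoder, dataset)
record = extractor(0)  # embeddings (and pixel gradients) for the first sample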
Example #2
import pickle
import time

import nltk
import numpy as np
import torch

# Project-local helpers (VSE, get_test_loader, AverageMeter, LogCollector,
# order_sim) are assumed to be importable from the surrounding codebase.


class ImageRetriever(object):
    def __init__(self, model_path, img_path, precomp_path, split, vocab_path,
                 batch_size):
        checkpoint = torch.load(model_path)
        self.opt = checkpoint['opt']
        self.opt.data_path = img_path
        self.opt.data_name = 'coco'
        self.opt.batch_size = batch_size

        # loader over raw images, used later by visualise_NN
        self.img_loader = get_test_loader(split, self.opt.data_name, None,
                                          self.opt.crop_size,
                                          self.opt.batch_size,
                                          self.opt.workers, self.opt)

        self.opt.data_path = precomp_path
        self.opt.data_name = 'coco_precomp'

        # load vocabulary used by the model
        with open(vocab_path, 'rb') as f:
            self.vocab = pickle.load(f)
        self.vocab_size = len(self.vocab)

        self.precomp_loader = get_test_loader(split, self.opt.data_name,
                                              self.vocab, self.opt.crop_size,
                                              self.opt.batch_size,
                                              self.opt.workers, self.opt)

        self.model = VSE(self.opt)

        # load model state
        self.model.load_state_dict(checkpoint['model'])

        # switch to evaluation mode before precomputing the image embeddings
        self.model.val_start()

        # numpy array to keep all the embeddings
        self.img_embs = None
        batch_time = AverageMeter()
        val_logger = LogCollector()
        end = time.time()
        log_step = 10
        for i, (images, captions, lengths,
                ids) in enumerate(self.precomp_loader):
            self.model.logger = val_logger

            if torch.cuda.is_available():
                images = images.cuda()

            # Forward without tracking gradients (torch.no_grad() replaces
            # the deprecated Variable(..., volatile=True) idiom)
            with torch.no_grad():
                img_emb = self.model.img_enc(images)

            # initialize the numpy arrays given the size of the embeddings
            if self.img_embs is None:
                self.img_embs = np.zeros(
                    (len(self.precomp_loader.dataset), img_emb.size(1)))

            # preserve the embeddings by copying from gpu and converting to numpy
            self.img_embs[ids] = img_emb.data.cpu().numpy().copy()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_step == 0:
                print('Test: [{0}/{1}]\t'
                      '{e_log}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.
                      format(i,
                             len(self.precomp_loader),
                             batch_time=batch_time,
                             e_log=str(self.model.logger)))
            del images, captions

    def get_NN(self, query, measure='dot', k=5):
        # convert the query string to vocabulary indices
        tokens = nltk.tokenize.word_tokenize(str(query).lower())
        query = []
        query.append(self.vocab('<start>'))
        query.extend([self.vocab(token) for token in tokens])
        query.append(self.vocab('<end>'))

        # embed the query caption
        query = torch.LongTensor(query)
        if torch.cuda.is_available():
            query = query.cuda()

        # Forward with a batch dimension of 1; no gradients at test time
        # (torch.no_grad() replaces the deprecated volatile=True flag)
        with torch.no_grad():
            q_emb = self.model.txt_enc(query.unsqueeze(0), [query.size(0)])
        q_embs = q_emb.data.cpu().numpy().copy()

        # run nearest neighbours searching text -> image
        return self.find_NN(self.img_embs, q_embs, measure=measure, k=k)

    def find_NN(self, images, query, measure, k, npts=None):
        """
        Text->Images (image search)
        images: (5N, K) matrix of image embeddings (5 caption slots per image)
        query: (1, K) matrix of caption embeddings
        """
        if npts is None:
            npts = images.shape[0] // 5
        # keep one embedding per image (every 5th row repeats the same image)
        ims = np.array([images[i] for i in range(0, len(images), 5)])

        # Compute scores (time.perf_counter() replaces the removed time.clock())
        tic = time.perf_counter()
        if measure == 'order':
            d2 = order_sim(
                torch.Tensor(ims).cuda(),
                torch.Tensor(query).cuda())
            d2 = d2.cpu().numpy()

            d = d2.T
        else:
            d = np.dot(
                query, ims.T
            )  # TODO Try to optimize this computation, see if sorting is bottleneck

        imgs_by_similarity = np.argsort(np.squeeze(d))[::-1]
        toc = time.perf_counter()
        print('NN search took {} ms over {} images'.format(
            (toc - tic) * 1000.0, ims.shape[0]))

        # map image-row indices back to the 5-captions-per-image index space
        return imgs_by_similarity[0:k] * 5, query, ims

    def visualise_NN(self, inds, img_identifier, file_name):
        # save each retrieved image and write an HTML page that displays them
        with open(file_name, "w") as html_file:
            for i, ind in enumerate(inds):
                root, caption, img_id, path, image = self.img_loader.dataset.get_raw_item(
                    ind)
                image.save("./out/{}_img{}.png".format(img_identifier, i + 1))
                img_tag = '<img src="./out/{}_img{}.png" style="max-height: 400px; max-width: 400px;">'\
                    .format(img_identifier, i + 1)
                html_file.write(img_tag)
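
A usage sketch for the retriever, assuming the checkpoint and vocabulary layout produced by the VSE training scripts; every path below is a placeholder:

# Hypothetical usage; all paths are placeholders.
retriever = ImageRetriever(model_path='runs/model_best.pth.tar',
                           img_path='data/coco',
                           precomp_path='data/coco_precomp',
                           split='test',
                           vocab_path='vocab/coco_precomp_vocab.pkl',
                           batch_size=128)
inds, q_emb, ims = retriever.get_NN('a dog catching a frisbee', k=5)
retriever.visualise_NN(inds, img_identifier='dog', file_name='results.html')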
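
For reference, the ranking step inside find_NN reduces to a dot product followed by a descending argsort; a self-contained sketch with dummy embeddings:

# Standalone illustration of the 'dot' branch of find_NN, with random data.
import numpy as np

rng = np.random.default_rng(0)
ims = rng.standard_normal((100, 1024))  # 100 image embeddings (K = 1024)
q = rng.standard_normal((1, 1024))      # one caption embedding

scores = np.dot(q, ims.T)                       # (1, 100) similarity scores
ranking = np.argsort(np.squeeze(scores))[::-1]  # best match first
top5 = ranking[:5]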