Beispiel #1
0
    def embed(self,
              dataloader: torch.utils.data.DataLoader,
              device: torch.device = None,
              to_numpy: bool = True):
        """Embeds images in a vector space.

        Switches the model to evaluation mode and passes every batch of the
        dataloader through ``self.model.features`` with gradient tracking
        disabled. While iterating, the progress bar shows the fraction of
        each iteration that was spent on the forward pass (as opposed to
        waiting for and transferring the data).

        Args:
            dataloader:
                A PyTorch dataloader. Every batch must unpack into
                (images, labels, filenames).
            device:
                Selected device (see PyTorch documentation). Images and
                labels are passed to ``.to(device)`` before the forward pass.
            to_numpy:
                Whether to return the embeddings and labels as numpy arrays.

        Returns:
            A tuple (embeddings, labels, fnames). The embeddings are a tensor
            or ndarray with shape n_images x num_ftrs, the labels are a
            tensor or ndarray with the concatenated labels of all batches,
            and fnames is a list with the filenames in iteration order. If
            the dataloader yields no batches, embeddings and labels are
            empty and fnames is an empty list.

        Examples:
            >>> # embed images in vector space
            >>> embeddings, _, _ = encoder.embed(dataloader)

        """

        self.model.eval()

        if is_prefetch_generator_available():
            pbar = tqdm(BackgroundGenerator(dataloader, max_prefetch=3),
                        total=len(dataloader))
        else:
            pbar = tqdm(dataloader, total=len(dataloader))

        embeddings = []
        labels = []
        fnames = []
        with torch.no_grad():

            # perf_counter is monotonic and high-resolution, so the measured
            # intervals can neither be negative nor collapse to zero because
            # of a coarse or adjusted wall clock.
            start_time = time.perf_counter()
            for (img, label, fname) in pbar:

                img = img.to(device)
                label = label.to(device)

                fnames.extend(fname)

                batch_size = img.shape[0]
                prepare_time = time.perf_counter()

                emb = self.model.features(img)
                # flatten everything but the batch dimension
                emb = emb.detach().reshape(batch_size, -1)

                embeddings.append(emb)
                labels.append(label)

                process_time = time.perf_counter()

                # share of this iteration spent in the forward pass; the
                # guard avoids a ZeroDivisionError on a zero-length interval
                total_time = process_time - start_time
                if total_time > 0:
                    efficiency = (process_time - prepare_time) / total_time
                else:
                    efficiency = 0.
                pbar.set_description(
                    "Compute efficiency: {:.2f}".format(efficiency))
                start_time = time.perf_counter()

            if embeddings:
                embeddings = torch.cat(embeddings, 0)
                labels = torch.cat(labels, 0)
            else:
                # torch.cat raises on an empty list, so an empty dataloader
                # is answered with empty tensors instead of an exception
                embeddings = torch.empty(0, device=device)
                labels = torch.empty(0, dtype=torch.long, device=device)

            if to_numpy:
                embeddings = embeddings.cpu().numpy()
                labels = labels.cpu().numpy()

        return embeddings, labels, fnames
Beispiel #2
0
""" Embedding Strategies """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import time

import torch
from lightly import is_prefetch_generator_available
from lightly.embedding._base import BaseEmbedding
from tqdm import tqdm

if is_prefetch_generator_available():
    from prefetch_generator import BackgroundGenerator


class SelfSupervisedEmbedding(BaseEmbedding):
    """Implementation of self-supervised embedding models.

    Implements an embedding strategy based on self-supervised learning. A
    model backbone, self-supervised criterion, optimizer, and dataloader are
    passed to the constructor. The embedding itself is a pytorch-lightning
    module which can be trained very easily:

    https://pytorch-lightning.readthedocs.io/en/stable/

    The implementation is based on contrastive learning.

    MCM: https://arxiv.org/abs/1906.05849

    SimCLR: https://arxiv.org/abs/2002.05709