Example #1
0
    def _compute_fid_score(self) -> float:
        """Computes FID score for the dataset

        Returns:
            float: FID score
        """

        ds_config = self._config['dataset']

        # the network used to calculate FID expects 299x299 inputs,
        # so all images are resized accordingly
        image_size = 299

        # MakeDataLoader accepts custom data/annotation paths directly
        # and a configurable image size
        loader_factory = MakeDataLoader(ds_config['path'],
                                        ds_config['anno'],
                                        image_size,
                                        augmented=False)
        test_set = loader_factory.dataset_test
        valid_set = loader_factory.dataset_valid

        fid_fn = get_fid_fn(test_set, valid_set, self._device, len(test_set))
        with torch.no_grad():
            score = fid_fn(self._generator)
        return score
Example #2
0
    def _get_dl(self) -> DataLoader:
        """Creates infinite dataloader from valid and test sets of images

        Returns:
            DataLoader: created dataloader
        """

        batch_size = self._config['batch_size']
        num_workers = self._config['n_workers']
        ds_config = self._config['dataset']

        loader_factory = MakeDataLoader(ds_config['path'],
                                        ds_config['anno'],
                                        ds_config['size'],
                                        augmented=True)
        # valid and test splits are merged into a single image pool
        combined = ConcatDataset(
            [loader_factory.dataset_valid, loader_factory.dataset_test])
        base_loader = DataLoader(combined,
                                 batch_size,
                                 True,
                                 num_workers=num_workers,
                                 drop_last=True)
        return infinite_loader(base_loader)
Example #3
0
    def _compute_inception_score(self) -> float:
        """Computes inception score (IS) for the model

        Returns:
            float: inception score (IS)
        """

        batch_size = self._config['batch_size']
        ds_config = self._config['dataset']

        loader_factory = MakeDataLoader(ds_config['path'],
                                        ds_config['anno'],
                                        ds_config['size'],
                                        augmented=False)
        valid_set = loader_factory.dataset_valid

        # wrap the generator so it can be consumed as a dataset of
        # generated images, one sample per validation example
        gan_ds = GANDataset(self._generator, valid_set, self._device,
                            len(valid_set))
        # inception_score returns a tuple; only the first element is reported
        return inception_score(gan_ds, batch_size=batch_size, resize=True)[0]
Example #4
0
    def _chamfer_distance(self, encoder_type: str = 'simclr') -> float:
        """Computes Chamfer distance between real and generated samples

        Embeddings of real and generated images are projected to 3D with
        t-SNE, and the Chamfer distance between the two resulting point
        clouds is returned.

        Args:
            encoder_type: type of encoder to use. Choices: simclr, ae

        Returns:
            float: Chamfer distance

        Raises:
            ValueError: if encoder_type is not one of: simclr, ae
        """

        if encoder_type not in ['simclr', 'ae']:
            raise ValueError('Incorrect encoder')

        if encoder_type == 'simclr':
            encoder = self._encoder
        else:
            encoder = Encoder().to(self._device).eval()
            ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
            encoder.load_state_dict(ckpt_encoder)

        bs = self._config['batch_size']

        path = self._config['dataset']['path']
        anno = self._config['dataset']['anno']
        size = self._config['dataset']['size']

        make_dl = MakeDataLoader(path, anno, size, augmented=True)
        dl_valid = make_dl.get_data_loader_valid(bs)
        dl_test = make_dl.get_data_loader_test(bs)

        embeddings_real = []
        embeddings_gen = []

        for batch_val, batch_test in zip(dl_valid, dl_test):
            img, _ = batch_test
            _, lbl = batch_val

            # the two loaders may yield different sizes on their final
            # batches; truncate both to the common size so the latent
            # matches the labels and the real/generated embedding counts
            # stay equal (the t-SNE output below is split exactly in half)
            s = min(img.shape[0], lbl.shape[0])

            img = img[:s].to(self._device)
            img = (img - 0.5) / 0.5  # renormalize

            lbl = lbl[:s].to(self._device)
            latent = torch.randn((s, self._generator.dim_z)).to(self._device)

            with torch.no_grad():
                img_gen = self._generator(latent, lbl)
                img_gen = (img_gen - 0.5) / 0.5  # renormalize

                h, _ = encoder(img)
                h_gen, _ = encoder(img_gen)

            embeddings_real.extend(h.cpu().numpy())
            embeddings_gen.extend(h_gen.cpu().numpy())

        embeddings_real = np.array(embeddings_real, dtype=np.float32)
        embeddings_gen = np.array(embeddings_gen, dtype=np.float32)
        embeddings = np.concatenate((embeddings_real, embeddings_gen))
        tsne_emb = TSNE(n_components=3, n_jobs=16).fit_transform(embeddings)

        # first half of the projection is real, second half is generated
        n = len(tsne_emb)
        tsne_real = np.array(tsne_emb[:n // 2, ], dtype=np.float32)
        tsne_fake = np.array(tsne_emb[n // 2:, ], dtype=np.float32)

        tsne_real = torch.from_numpy(tsne_real).unsqueeze(0)
        tsne_fake = torch.from_numpy(tsne_fake).unsqueeze(0)

        chamfer_dist = ChamferDistance()
        return chamfer_dist(tsne_real, tsne_fake).detach().item()
Example #5
0
    def _compute_ssl_fid(self, encoder_type: str = 'simclr') -> float:
        """Computes FID on SSL features

        Activations of real and generated images are extracted with the
        chosen encoder; a Gaussian is fitted to each set and the Frechet
        distance between the two fits is returned.

        Args:
            encoder_type: type of encoder to use. Choices: simclr, ae

        Returns:
            float: FID

        Raises:
            ValueError: if encoder_type is not one of: simclr, ae
        """

        if encoder_type not in ['simclr', 'ae']:
            raise ValueError('Incorrect encoder')

        if encoder_type == 'simclr':
            encoder = self._encoder
        else:
            encoder = Encoder().to(self._device).eval()
            ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
            encoder.load_state_dict(ckpt_encoder)

        bs = self._config['batch_size']

        path = self._config['dataset']['path']
        anno = self._config['dataset']['anno']
        size = self._config['dataset']['size']

        make_dl = MakeDataLoader(path, anno, size, augmented=False)
        dl_valid = make_dl.get_data_loader_valid(bs)
        dl_test = make_dl.get_data_loader_test(bs)

        # compute activations
        activations_real = []
        activations_fake = []

        for batch_val, batch_test in zip(dl_valid, dl_test):
            img, _ = batch_test
            _, lbl = batch_val

            # final batches may be smaller than bs (and may differ between
            # the two loaders); truncate to the common size and size the
            # latent from it instead of the configured bs, otherwise the
            # generator receives mismatched latent/label batches
            s = min(img.shape[0], lbl.shape[0])

            img = img[:s].to(self._device)
            img = (img - 0.5) / 0.5  # renormalize

            lbl = lbl[:s].to(self._device)
            latent = torch.randn((s, self._generator.dim_z)).to(self._device)

            with torch.no_grad():
                img_gen = self._generator(latent, lbl)
                img_gen = (img_gen - 0.5) / 0.5

                h, _ = encoder(img)
                h_gen, _ = encoder(img_gen)

            activations_real.extend(h.cpu().numpy())
            activations_fake.extend(h_gen.cpu().numpy())

        activations_real = np.array(activations_real)
        activations_fake = np.array(activations_fake)

        # fit a Gaussian (mean + covariance) to each activation set
        mu_real = np.mean(activations_real, axis=0)
        sigma_real = np.cov(activations_real, rowvar=False)

        mu_fake = np.mean(activations_fake, axis=0)
        sigma_fake = np.cov(activations_fake, rowvar=False)
        frechet_distance = calculate_frechet_distance(mu_fake, sigma_fake,
                                                      mu_real, sigma_real)
        return frechet_distance
Example #6
0
    def _compute_kid(self, encoder_type: str = 'simclr') -> float:
        """Computes KID score

        Args:
            encoder_type: type of encoder to use. Choices: simclr, inception, ae

        Returns:
            float: KID score

        Raises:
            ValueError: if encoder_type is not one of: simclr, inception, ae
        """

        if encoder_type not in ['simclr', 'inception', 'ae']:
            raise ValueError('Incorrect encoder')

        if encoder_type == 'simclr':
            encoder = self._encoder
        elif encoder_type == 'inception':
            encoder = load_patched_inception_v3().to(self._device).eval()
        else:
            encoder = Encoder().to(self._device).eval()
            ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
            encoder.load_state_dict(ckpt_encoder)

        bs = self._config['batch_size']
        path = self._config['dataset']['path']
        anno = self._config['dataset']['anno']
        size = self._config['dataset']['size']
        make_dl = MakeDataLoader(path, anno, size, augmented=False)
        dl_valid = make_dl.get_data_loader_valid(bs)
        dl_test = make_dl.get_data_loader_test(bs)

        features_real = []
        features_gen = []

        for batch_val, batch_test in zip(dl_valid, dl_test):
            _, lbl = batch_val
            img, _ = batch_test

            # final batches may be smaller than bs (and may differ between
            # the two loaders); truncate to the common size and size the
            # latent from it, not from the configured bs, so the generator
            # receives matching latent/label batches
            s = min(img.shape[0], lbl.shape[0])

            img = img[:s].to(self._device)
            img = (img - 0.5) / 0.5  # renormalize

            lbl = lbl[:s].to(self._device)
            latent = torch.randn((s, self._generator.dim_z)).to(self._device)

            with torch.no_grad():
                img_gen = self._generator(latent, lbl)

                if encoder_type == 'inception':
                    # the inception network expects 299x299 inputs
                    if img.shape[2] != 299 or img.shape[3] != 299:
                        img = torch.nn.functional.interpolate(img,
                                                              size=(299, 299),
                                                              mode='bicubic')

                    img_gen = torch.nn.functional.interpolate(img_gen,
                                                              size=(299, 299),
                                                              mode='bicubic')
                    img_gen = (img_gen - 0.5) / 0.5

                    h = encoder(img)[0].flatten(start_dim=1)
                    h_gen = encoder(img_gen)[0].flatten(start_dim=1)
                else:
                    h, _ = encoder(img)
                    h_gen, _ = encoder(img_gen)
            features_real.extend(h.cpu().numpy())
            features_gen.extend(h_gen.cpu().numpy())

        features_real = np.array(features_real)
        features_gen = np.array(features_gen)
        num_subsets = 100
        # max subset size, clamped so that sampling without replacement
        # never requests more items than were collected
        m = min(1000, features_real.shape[0], features_gen.shape[0])

        n = features_real.shape[1]  # feature dimensionality
        t = 0
        # unbiased MMD estimate with a cubic polynomial kernel, averaged
        # over random subsets of the features
        for _ in range(num_subsets):
            x = features_gen[np.random.choice(features_gen.shape[0],
                                              m,
                                              replace=False)]
            y = features_real[np.random.choice(features_real.shape[0],
                                               m,
                                               replace=False)]
            a = (x @ x.T / n + 1)**3 + (y @ y.T / n + 1)**3
            b = (x @ y.T / n + 1)**3
            t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
        kid = t / num_subsets / m
        return float(kid)
Example #7
0
    def _compute_geometric_distance(self,
                                    encoder_type: str = 'simclr') -> float:
        """Computes geometric distance between real and generated samples using
        features computed using SimCLR

        Args:
            encoder_type: type of encoder to use. Choices: simclr, ae

        Returns:
            float: geometric distance

        Raises:
            ValueError: if encoder_type is not one of: simclr, ae
        """

        if encoder_type not in ['simclr', 'ae']:
            raise ValueError('Incorrect encoder')

        if encoder_type == 'simclr':
            encoder = self._encoder
        else:
            encoder = Encoder().to(self._device).eval()
            ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
            encoder.load_state_dict(ckpt_encoder)

        # Sinkhorn divergence between the two embedding point clouds
        loss = SamplesLoss("sinkhorn",
                           p=2,
                           blur=0.05,
                           scaling=0.8,
                           backend="tensorized")

        bs = self._config['batch_size']

        # load dataset
        path = self._config['dataset']['path']
        anno = self._config['dataset']['anno']
        size = self._config['dataset']['size']

        make_dl = MakeDataLoader(path, anno, size, augmented=True)
        dl_valid = make_dl.get_data_loader_valid(bs)
        dl_test = make_dl.get_data_loader_test(bs)

        embeddings_real = []
        embeddings_gen = []

        for batch_val, batch_test in zip(dl_valid, dl_test):
            _, lbl = batch_val
            img, _ = batch_test

            # final batches may be smaller than bs (and may differ between
            # the two loaders); truncate to the common size and size the
            # latent from it so the generator receives matching batches
            s = min(img.shape[0], lbl.shape[0])

            img = img[:s].to(self._device)
            img = (img - 0.5) / 0.5  # renormalize image
            lbl = lbl[:s].to(self._device)
            latent = torch.randn((s, self._generator.dim_z)).to(self._device)

            with torch.no_grad():
                img_gen = self._generator(latent, lbl)
                # renormalize generated images before encoding, consistent
                # with the other feature-based metrics in this class
                img_gen = (img_gen - 0.5) / 0.5
                h, _ = encoder(img)
                h_gen, _ = encoder(img_gen)

            embeddings_real.extend(h.detach().cpu())
            embeddings_gen.extend(h_gen.detach().cpu())

        embeddings_real = torch.stack(embeddings_real)
        embeddings_gen = torch.stack(embeddings_gen)
        distance = loss(embeddings_real, embeddings_gen)
        return distance.detach().cpu().item()
Example #8
0
    def _attribute_control_accuracy(self,
                                    build_hist: bool = True,
                                    out_dir: PathOrStr = './images') -> Dict:
        """Computes attribute control accuracy

        Images are generated from the validation labels, re-classified, and
        the per-attribute squared error between requested and predicted
        labels is aggregated.

        Args:
            build_hist: if True, the histogram of differences for each label will be built and saved

            out_dir: path to directory, where to save histogram images

        Returns:
            Dict: attribute control accuracy for each label
        """

        # load dataset
        bs = self._config['batch_size']

        path = self._config['dataset']['path']
        anno = self._config['dataset']['anno']
        size = self._config['dataset']['size']
        make_dl = MakeDataLoader(path, anno, size, augmented=True)
        dl_valid = make_dl.get_data_loader_valid(batch_size=bs)

        n_out = self._config['dataset']['n_out']
        diffs = []
        labels = []

        for batch in tqdm(dl_valid):
            _, label = batch
            label = label.to(self._device)
            # size the latent to the actual batch: the final batch may be
            # smaller than the configured bs, and a fixed-size latent would
            # not match the labels fed to the generator
            latent = torch.randn(
                (label.shape[0], self._generator.dim_z)).to(self._device)

            with torch.no_grad():
                img = self._generator(latent, label)
                pred = self._classifier(img)

            # per-attribute squared error between requested and predicted
            diff = (label - pred)**2
            diffs.extend(diff.detach().cpu().numpy())
            labels.extend(label.detach().cpu().numpy())

        diffs = np.array(diffs)
        labels = np.array(labels)

        if build_hist:
            out_dir = Path(out_dir)
            out_dir.mkdir(parents=True, exist_ok=True)

            for i in range(n_out):
                column = self._columns[i]
                plt.figure()
                plt.title(f'{column}. Attribute control accuracy')
                plt.hist(diffs[:, i], bins=100)
                plt.savefig(out_dir / f'{column}.png', dpi=300)
                # release the figure; leaving all figures open grows
                # memory with every attribute plotted
                plt.close()

        mean_diffs = np.mean(diffs, axis=0)

        result = {}
        for i in range(n_out):
            result[self._columns[i]] = mean_diffs[i]
        result['aggregated_attribute_accuracy'] = np.sum(diffs) / np.sum(
            labels)
        return result