def _compute_fid_score(self) -> float:
    """Computes FID score for the dataset

    Returns:
        float: FID score
    """
    # load dataset
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    # all images are resized to 299x299, the input size of the Inception network used to compute FID
    size = 299
    # the updated version of MakeDataLoader is used, where the image size can be changed
    # and custom paths to data and annotations can be passed directly
    make_dl = MakeDataLoader(path, anno, size, augmented=False)
    ds_test = make_dl.dataset_test
    ds_valid = make_dl.dataset_valid

    n_samples = len(ds_test)
    fid_func = get_fid_fn(ds_test, ds_valid, self._device, n_samples)

    with torch.no_grad():
        fid_score = fid_func(self._generator)
    return fid_score
def _get_dl(self) -> DataLoader:
    """Creates infinite dataloader from valid and test sets of images

    Returns:
        DataLoader: created dataloader
    """
    bs = self._config['batch_size']
    n_workers = self._config['n_workers']

    # load dataset
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=True)
    ds_valid = make_dl.dataset_valid
    ds_test = make_dl.dataset_test
    ds = ConcatDataset([ds_valid, ds_test])

    dl = infinite_loader(
        DataLoader(ds, bs, shuffle=True, num_workers=n_workers, drop_last=True))
    return dl
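# A minimal sketch (an assumption, not the project's actual helper) of the behaviour
# expected from `infinite_loader` above: wrap a finite DataLoader so that iteration
# transparently restarts once the loader is exhausted. The hypothetical name
# `_infinite_loader_sketch` is used to avoid shadowing the imported helper.
def _infinite_loader_sketch(data_loader: DataLoader):
    """Yield batches from `data_loader` forever, restarting it when exhausted."""
    while True:
        for batch in data_loader:
            yield batch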
def _compute_inception_score(self) -> float:
    """Computes inception score (IS) for the model

    Returns:
        float: inception score (IS)
    """
    batch_size = self._config['batch_size']

    # load dataset
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=False)
    ds = make_dl.dataset_valid

    n_samples = len(ds)
    dataset = GANDataset(self._generator, ds, self._device, n_samples)
    score = inception_score(dataset, batch_size=batch_size, resize=True)[0]
    return score
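# Reference sketch of the Inception Score definition that `inception_score` above is
# assumed to implement: IS = exp(E_x[KL(p(y|x) || p(y))]) over classifier softmax
# outputs. Illustrative only; `_inception_score_from_probs` is a hypothetical helper.
def _inception_score_from_probs(probs: np.ndarray, eps: float = 1e-12) -> float:
    """probs: (n_samples, n_classes) array of softmax outputs p(y|x)."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))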
def _chamfer_distance(self, encoder_type: str = 'simclr') -> float:
    """Computes Chamfer distance between real and generated samples

    Args:
        encoder_type: type of encoder to use. Choices: simclr, ae

    Returns:
        float: Chamfer distance
    """
    if encoder_type not in ['simclr', 'ae']:
        raise ValueError('Incorrect encoder')

    if encoder_type == 'simclr':
        encoder = self._encoder
    else:
        encoder = Encoder().to(self._device).eval()
        ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
        encoder.load_state_dict(ckpt_encoder)

    bs = self._config['batch_size']
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=True)
    dl_valid = make_dl.get_data_loader_valid(bs)
    dl_test = make_dl.get_data_loader_test(bs)

    embeddings_real = []
    embeddings_gen = []

    for batch_val, batch_test in zip(dl_valid, dl_test):
        img, _ = batch_test
        _, lbl = batch_val

        img = img.to(self._device)
        img = (img - 0.5) / 0.5  # renormalize
        s = img.shape[0]
        lbl = lbl.to(self._device)
        latent = torch.randn((s, self._generator.dim_z)).to(self._device)

        with torch.no_grad():
            img_gen = self._generator(latent, lbl)
            img_gen = (img_gen - 0.5) / 0.5  # renormalize
            h, _ = encoder(img)
            h_gen, _ = encoder(img_gen)

        embeddings_real.extend(h.cpu().numpy())
        embeddings_gen.extend(h_gen.cpu().numpy())

    embeddings_real = np.array(embeddings_real, dtype=np.float32)
    embeddings_gen = np.array(embeddings_gen, dtype=np.float32)
    embeddings = np.concatenate((embeddings_real, embeddings_gen))

    tsne_emb = TSNE(n_components=3, n_jobs=16).fit_transform(embeddings)
    n = len(tsne_emb)
    tsne_real = np.array(tsne_emb[:n // 2], dtype=np.float32)
    tsne_fake = np.array(tsne_emb[n // 2:], dtype=np.float32)

    tsne_real = torch.from_numpy(tsne_real).unsqueeze(0)
    tsne_fake = torch.from_numpy(tsne_fake).unsqueeze(0)

    chamfer_dist = ChamferDistance()
    return chamfer_dist(tsne_real, tsne_fake).detach().item()
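# Reference sketch of the symmetric Chamfer distance that `ChamferDistance()` above is
# assumed to compute on the two t-SNE point clouds (the library may sum instead of
# average, or skip squaring); `_chamfer_distance_sketch` is a hypothetical helper.
def _chamfer_distance_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a: (n, d) and b: (m, d) point sets; returns the mean nearest-neighbour squared
    distance from a to b plus the same from b to a."""
    d = torch.cdist(a, b) ** 2  # (n, m) pairwise squared Euclidean distances
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()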
def _compute_ssl_fid(self, encoder_type: str = 'simclr') -> float:
    """Computes FID on SSL features

    Args:
        encoder_type: type of encoder to use. Choices: simclr, ae

    Returns:
        float: FID
    """
    if encoder_type not in ['simclr', 'ae']:
        raise ValueError('Incorrect encoder')

    if encoder_type == 'simclr':
        encoder = self._encoder
    else:
        encoder = Encoder().to(self._device).eval()
        ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
        encoder.load_state_dict(ckpt_encoder)

    bs = self._config['batch_size']
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=False)
    dl_valid = make_dl.get_data_loader_valid(bs)
    dl_test = make_dl.get_data_loader_test(bs)

    # compute activations
    activations_real = []
    activations_fake = []

    for batch_val, batch_test in zip(dl_valid, dl_test):
        img, _ = batch_test
        _, lbl = batch_val

        img = img.to(self._device)
        img = (img - 0.5) / 0.5  # renormalize
        lbl = lbl.to(self._device)
        # use the actual batch size: the last batch can be smaller than bs
        latent = torch.randn((lbl.shape[0], self._generator.dim_z)).to(self._device)

        with torch.no_grad():
            img_gen = self._generator(latent, lbl)
            img_gen = (img_gen - 0.5) / 0.5
            h, _ = encoder(img)
            h_gen, _ = encoder(img_gen)

        activations_real.extend(h.cpu().numpy())
        activations_fake.extend(h_gen.cpu().numpy())

    activations_real = np.array(activations_real)
    activations_fake = np.array(activations_fake)

    mu_real = np.mean(activations_real, axis=0)
    sigma_real = np.cov(activations_real, rowvar=False)
    mu_fake = np.mean(activations_fake, axis=0)
    sigma_fake = np.cov(activations_fake, rowvar=False)

    frechet_distance = calculate_frechet_distance(mu_fake, sigma_fake, mu_real, sigma_real)
    return frechet_distance
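# Reference sketch of the Frechet distance formula assumed to be implemented by
# `calculate_frechet_distance` above (the real helper typically adds numerical
# safeguards around the matrix square root):
#     d^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))
def _frechet_distance_sketch(mu1: np.ndarray, sigma1: np.ndarray,
                             mu2: np.ndarray, sigma2: np.ndarray) -> float:
    from scipy import linalg  # local import to keep the sketch self-contained

    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))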
def _compute_kid(self, encoder_type: str = 'simclr') -> float:
    """Computes KID score

    Args:
        encoder_type: type of encoder to use. Choices: simclr, inception, ae

    Returns:
        float: KID score
    """
    if encoder_type not in ['simclr', 'inception', 'ae']:
        raise ValueError('Incorrect encoder')

    if encoder_type == 'simclr':
        encoder = self._encoder
    elif encoder_type == 'inception':
        encoder = load_patched_inception_v3().to(self._device).eval()
    else:
        encoder = Encoder().to(self._device).eval()
        ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
        encoder.load_state_dict(ckpt_encoder)

    bs = self._config['batch_size']
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=False)
    dl_valid = make_dl.get_data_loader_valid(bs)
    dl_test = make_dl.get_data_loader_test(bs)

    features_real = []
    features_gen = []

    for batch_val, batch_test in zip(dl_valid, dl_test):
        _, lbl = batch_val
        img, _ = batch_test

        img = img.to(self._device)
        img = (img - 0.5) / 0.5  # renormalize
        lbl = lbl.to(self._device)
        # use the actual batch size: the last batch can be smaller than bs
        latent = torch.randn((lbl.shape[0], self._generator.dim_z)).to(self._device)

        with torch.no_grad():
            img_gen = self._generator(latent, lbl)

            if encoder_type == 'inception':
                # the Inception network expects 299x299 inputs
                if img.shape[2] != 299 or img.shape[3] != 299:
                    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bicubic')
                    img_gen = torch.nn.functional.interpolate(img_gen, size=(299, 299), mode='bicubic')

                img_gen = (img_gen - 0.5) / 0.5
                h = encoder(img)[0].flatten(start_dim=1)
                h_gen = encoder(img_gen)[0].flatten(start_dim=1)
            else:
                h, _ = encoder(img)
                h_gen, _ = encoder(img_gen)

        features_real.extend(h.cpu().numpy())
        features_gen.extend(h_gen.cpu().numpy())

    features_real = np.array(features_real)
    features_gen = np.array(features_gen)

    m = 1000  # max subset size
    num_subsets = 100
    n = features_real.shape[1]  # feature dimensionality
    t = 0

    # unbiased MMD^2 estimate with a cubic polynomial kernel, averaged over random subsets
    for _ in range(num_subsets):
        x = features_gen[np.random.choice(features_gen.shape[0], m, replace=False)]
        y = features_real[np.random.choice(features_real.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m

    kid = t / num_subsets / m
    return float(kid)
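# The loop above is the subset-averaged, unbiased MMD^2 estimator with a cubic
# polynomial kernel k(u, v) = (u.v / d + 1)^3. A compact single-subset restatement,
# written out as a hypothetical helper for clarity (assumes x and y have equal size):
def _polynomial_mmd2(x: np.ndarray, y: np.ndarray) -> float:
    d = x.shape[1]
    m = x.shape[0]
    kxx = (x @ x.T / d + 1.0) ** 3
    kyy = (y @ y.T / d + 1.0) ** 3
    kxy = (x @ y.T / d + 1.0) ** 3

    def off_diag_mean(k: np.ndarray) -> float:
        # average of kernel values excluding the diagonal (unbiased estimate)
        return (k.sum() - np.diag(k).sum()) / (m * (m - 1))

    return float(off_diag_mean(kxx) + off_diag_mean(kyy) - 2.0 * kxy.mean())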
def _compute_geometric_distance(self, encoder_type: str = 'simclr') -> float:
    """Computes geometric distance between real and generated samples
    using features from the selected encoder

    Args:
        encoder_type: type of encoder to use. Choices: simclr, ae

    Returns:
        float: geometric distance
    """
    if encoder_type not in ['simclr', 'ae']:
        raise ValueError('Incorrect encoder')

    if encoder_type == 'simclr':
        encoder = self._encoder
    else:
        encoder = Encoder().to(self._device).eval()
        ckpt_encoder = torch.load(self._config['eval']['path_encoder'])
        encoder.load_state_dict(ckpt_encoder)

    loss = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.8, backend="tensorized")

    bs = self._config['batch_size']
    # load dataset
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=True)
    dl_valid = make_dl.get_data_loader_valid(bs)
    dl_test = make_dl.get_data_loader_test(bs)

    embeddings_real = []
    embeddings_gen = []

    for batch_val, batch_test in zip(dl_valid, dl_test):
        _, lbl = batch_val
        img, _ = batch_test

        img = img.to(self._device)
        img = (img - 0.5) / 0.5  # renormalize image
        lbl = lbl.to(self._device)
        # use the actual batch size: the last batch can be smaller than bs
        latent = torch.randn((lbl.shape[0], self._generator.dim_z)).to(self._device)

        with torch.no_grad():
            img_gen = self._generator(latent, lbl)
            h, _ = encoder(img)
            h_gen, _ = encoder(img_gen)

        embeddings_real.extend(h.detach().cpu())
        embeddings_gen.extend(h_gen.detach().cpu())

    embeddings_real = torch.stack(embeddings_real)
    embeddings_gen = torch.stack(embeddings_gen)
    distance = loss(embeddings_real, embeddings_gen)
    return distance.detach().cpu().item()
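# Illustrative note on the Sinkhorn divergence used above (an assumption about the
# geomloss semantics, not a statement from this project): with p=2 and a small blur,
# SamplesLoss("sinkhorn") approximates the entropic-regularized squared Wasserstein-2
# cost between the two embedding clouds, up to the library's normalization, and is
# debiased so that the distance of a cloud to itself is close to zero.
def _sinkhorn_smoke_test(n: int = 256, dim: int = 64) -> float:
    """Hypothetical sanity check: distance between a random cloud and a shifted copy."""
    x = torch.randn(n, dim)
    y = x + 0.5  # shifting the cloud should yield a clearly positive distance
    loss = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.8, backend="tensorized")
    return loss(x, y).item()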
def _attribute_control_accuracy(self, build_hist: bool = True,
                                out_dir: PathOrStr = './images') -> Dict:
    """Computes attribute control accuracy

    Args:
        build_hist: if True, the histogram of differences for each label will be built and saved
        out_dir: path to directory, where to save histogram images

    Returns:
        Dict: attribute control accuracy for each label
    """
    # load dataset
    bs = self._config['batch_size']
    path = self._config['dataset']['path']
    anno = self._config['dataset']['anno']
    size = self._config['dataset']['size']

    make_dl = MakeDataLoader(path, anno, size, augmented=True)
    dl_valid = make_dl.get_data_loader_valid(batch_size=bs)
    n_out = self._config['dataset']['n_out']

    diffs = []
    labels = []

    for batch in tqdm(dl_valid):
        img, label = batch
        label = label.to(self._device)
        # use the actual batch size: the last batch can be smaller than bs
        latent = torch.randn((label.shape[0], self._generator.dim_z)).to(self._device)

        with torch.no_grad():
            img = self._generator(latent, label)
            pred = self._classifier(img)

        diff = (label - pred) ** 2
        diffs.extend(diff.detach().cpu().numpy())
        labels.extend(label.detach().cpu().numpy())

    diffs = np.array(diffs)
    labels = np.array(labels)

    if build_hist:
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        for i in range(n_out):
            column = self._columns[i]
            plt.figure()
            plt.title(f'{column}. Attribute control accuracy')
            plt.hist(diffs[:, i], bins=100)
            plt.savefig(out_dir / f'{column}.png', dpi=300)
            plt.close()  # avoid accumulating open figures

    mean_diffs = np.mean(diffs, axis=0)

    result = {}
    for i in range(n_out):
        result[self._columns[i]] = mean_diffs[i]

    result['aggregated_attribute_accuracy'] = np.sum(diffs) / np.sum(labels)
    return result