def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    pred_t = torch.rand(pred)
    target_t = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(pred_t, target_t)

    pred = torch.rand(pred)
    target = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(pred, target, kernel, sigma)
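# A sketch of a parametrization that could drive the test above. The concrete
# shapes and kernel/sigma values are illustrative assumptions, not the
# project's actual grid: mismatched dtypes trigger the TypeError branch, and
# malformed shapes or kernel/sigma arguments trigger the ValueError branch.
@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
    pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),         # not 4D
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),       # kernel length mismatch
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),    # non-positive kernel size
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, -1.5]),  # negative sigma
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    ...  # body as above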
def test_ssim(size, channel, plus, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)
    # The scikit-image reference runs on an independent random image with the
    # same additive offset, so the comparison is only statistical -- hence the
    # loose atol/rtol of 1e-2.
    np_pred = np.random.rand(size, size, channel)
    if multichannel is False:
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
def test_ssim(size, channel, coef, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(size, channel, size, size, device=device)
    target = pred * coef
    ssim_idx = ssim(pred, target, data_range=1.0)
    # Unlike the variant above, the reference here runs on the exact same
    # data, so the tolerance can be much tighter.
    np_pred = pred.permute(0, 2, 3, 1).cpu().numpy()
    if multichannel is False:
        np_pred = np_pred[:, :, :, 0]
    np_target = np.multiply(np_pred, coef)
    sk_ssim_idx = ski_ssim(
        np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0
    )
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
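# An illustrative parameter grid for the test above; these sizes and
# coefficients are assumptions chosen to cover single- and multi-channel
# inputs, not the project's actual fixtures.
@pytest.mark.parametrize(['size', 'channel', 'coef', 'multichannel'], [
    pytest.param(12, 3, 0.9, True),
    pytest.param(13, 1, 0.8, False),
    pytest.param(16, 1, 0.7, False),
    pytest.param(16, 3, 0.7, True),
])
def test_ssim(size, channel, coef, multichannel):
    ...  # body as above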
def _test_step(self, batch):
    img, target = batch
    pred, kl = self(x=img)
    pred_imgs = [pred]
    # iterative refining of predictions
    for _ in range(self.refine_steps):
        pred, kl = self(x=img, y=pred)
        pred_imgs.append(pred)
    mse_loss = ((pred - target) ** 2).mean(dim=(1, 2, 3))
    loss = mse_loss + (self.hparams.kl_coeff * kl)
    loss = loss.mean()
    ssim_val = ssim(pred, target)
    logs = {
        "mse_loss": mse_loss.mean(),
        "kl": kl.mean(),
        "scaled_kl": self.hparams.kl_coeff * kl.mean(),
        "loss": loss,
        "ssim": ssim_val,
    }
    img_logs = {
        "input": img[0].squeeze(0),
        "pred_imgs": [p[0] for p in pred_imgs],
        "target": target[0],
    }
    return loss, logs, img_logs
def _train_step(self, batch):
    img, target = batch
    pred, kl = self(x=img, y=target)
    mse_loss = ((pred - target) ** 2).mean(dim=(1, 2, 3))
    loss = mse_loss + (self.hparams.kl_coeff * kl)
    loss = loss.mean()
    ssim_val = ssim(pred, target)
    logs = {
        "mse_loss": mse_loss.mean(),
        "kl": kl.mean(),
        "scaled_kl": self.hparams.kl_coeff * kl.mean(),
        "loss": loss,
        "ssim": ssim_val,
    }
    img_logs = {
        "input": img[0].squeeze(0),
        "pred": pred[0],
        "target": target[0],
    }
    return loss, logs, img_logs
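# _train_step and _test_step look like shared helpers inside a LightningModule.
# A plausible (assumed) wiring into the standard hooks -- the hook names are
# PyTorch Lightning's, the prefixing scheme is a guess:
def training_step(self, batch, batch_idx):
    loss, logs, img_logs = self._train_step(batch)
    self.log_dict({f"train/{k}": v for k, v in logs.items()})
    return loss

def validation_step(self, batch, batch_idx):
    loss, logs, img_logs = self._test_step(batch)
    self.log_dict({f"val/{k}": v for k, v in logs.items()})
    return loss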
def ssim_lightning(real, fake, return_per_frame=False, normalize_range=True):
    # Promote single images (C, H, W) and single videos (T, C, H, W) to a
    # batched 5D layout (N, T, C, H, W).
    if real.dim() == 3:
        real = real[None, None]
        fake = fake[None, None]
    elif real.dim() == 4:
        real = real[None]
        fake = fake[None]
    if normalize_range:
        # Map [-1, 1] inputs onto [0, 1] before scoring.
        real = (real + 1.) / 2.
        fake = (fake + 1.) / 2.
    # Fold the time axis into the batch axis for the batch-level score.
    ssim_batch = PF.ssim(fake.reshape(-1, *fake.shape[2:]), real.reshape(-1, *real.shape[2:])).cpu().numpy()
    if return_per_frame:
        ssim_per_frame = {
            i: PF.ssim(fake[:, i], real[:, i]).cpu().numpy()
            for i in range(real.shape[1])
        }
        return ssim_batch, ssim_per_frame
    return ssim_batch
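# Usage sketch under assumed conventions: PF is taken to be
# pytorch_lightning.metrics.functional, and a video is laid out as
# (T, C, H, W) with values in [-1, 1], so ssim_lightning sees dim() == 4
# and adds the batch axis itself.
real = torch.rand(8, 1, 64, 64) * 2 - 1                     # one 8-frame video in [-1, 1]
fake = (real + 0.05 * torch.randn_like(real)).clamp(-1, 1)  # lightly perturbed copy
batch_score, per_frame = ssim_lightning(real, fake, return_per_frame=True)
print(batch_score)   # scalar ndarray: SSIM over all frames at once
print(per_frame[0])  # SSIM of frame 0 alone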
def iqa_metrics(data_loader, transform=None):
    metrics = {}
    if 'SVHN' not in args.dataset:
        Id_mat = torch.eye(n_dims, device=DEVICE).reshape(-1, *input_shape)
        metrics['l2-err'] = ((transform(Id_mat) - Id_mat).norm() / Id_mat.norm()).item()
    if args.dataset == 'CIFAR10' or args.dataset == 'MNIST':
        metrics['PSNR'] = 0
        metrics['SSIM'] = 0
        # metrics['HaarPsi'] = 0
        for inputs, labels in data_loader:
            images = inputs
            restored = transform(images)
            images = utility.to_zero_one(images)
            restored = utility.to_zero_one(restored)
            # for image, restored_image in zip(
            #         images.permute(0, 2, 3, 1).squeeze().cpu().numpy(),
            #         restored.permute(0, 2, 3, 1).squeeze().cpu().numpy()):
            #     metrics['HaarPsi'] += (haar_psi_numpy(image, restored_image)[0]
            #                            / len(data_loader) / len(inputs))
            if images.shape[1] == 3:
                images = utility.rbg_to_luminance(images)
                restored = utility.rbg_to_luminance(restored)
            metrics['PSNR'] += utility.psnr(images, restored).mean().item() / len(data_loader)
            # ssim defaults to the mean over the batch; shift from [-1, 1] to [0, 1]
            metrics['SSIM'] += (1 + ssim(images, restored).item()) / 2 / len(data_loader)
    if args.dataset == 'GMM':
        metrics['c-entropy'] = 0
        for inputs, labels in data_loader:
            metrics['c-entropy'] += dataset.cross_entropy(transform(inputs)).item() / len(data_loader)
    return metrics
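# The (1 + ssim(...)) / 2 term rescales SSIM from its native [-1, 1] range
# onto [0, 1], so it accumulates like the other "higher is better" scores.
# A tiny, purely illustrative sanity check of that mapping:
x = torch.rand(4, 1, 32, 32)
# identical inputs give SSIM = 1.0, which the rescaling keeps at 1.0;
# a hypothetical fully anti-correlated pair (SSIM = -1.0) would map to 0.0
assert abs((1 + ssim(x, x).item()) / 2 - 1.0) < 1e-6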
def step(self, batch):
    img, target = batch
    img = img.squeeze(1)
    pred = self(img)
    mse_loss = ((pred - target) ** 2).mean(dim=(1, 2, 3)).mean()
    ssim_val = ssim(pred, target)
    logs = {
        "mse_loss": mse_loss,
        "ssim": ssim_val,
    }
    img_logs = {
        "input": img[0].squeeze(0),
        "pred": pred[0],
        "target": target[0],
    }
    return mse_loss, logs, img_logs
def forward(self, pred, target):
    pred, target = map(self.to_slices, (pred, target))
    return ssim(pred, target, *self.ssim_args, **self.ssim_kwargs)
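# A guess at the module this forward lives in: to_slices presumably folds the
# depth axis of a 3D volume into the batch axis so a 2D SSIM can score slice
# pairs. A hypothetical sketch -- the (N, C, D, H, W) layout and the slice
# folding are assumptions, not the author's confirmed design.
class SlicewiseSSIM(torch.nn.Module):
    def __init__(self, *ssim_args, **ssim_kwargs):
        super().__init__()
        self.ssim_args = ssim_args
        self.ssim_kwargs = ssim_kwargs

    @staticmethod
    def to_slices(x):
        # (N, C, D, H, W) -> (N * D, C, H, W): treat each depth slice as an image
        n, c, d, h, w = x.shape
        return x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)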
def test_v1_5_metric_regress():
    ExplainedVariance.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        ExplainedVariance()

    MeanAbsoluteError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanAbsoluteError()

    MeanSquaredError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanSquaredError()

    MeanSquaredLogError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanSquaredLogError()

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    explained_variance._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = explained_variance(preds, target)
    assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4)

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    mean_absolute_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_absolute_error(x, y) == 0.25

    mean_relative_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_relative_error(x, y) == 0.125

    mean_squared_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_squared_error(x, y) == 0.25

    mean_squared_log_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = mean_squared_log_error(x, y)
    assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4)

    PSNR.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        PSNR()

    R2Score.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        R2Score()

    SSIM.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        SSIM()

    preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    psnr._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = psnr(preds, target)
    assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4)

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    r2score._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = r2score(preds, target)
    assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4)

    preds = torch.rand([16, 1, 16, 16])
    target = preds * 0.75
    ssim._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = ssim(preds, target)
    assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4)
def training_step(self, sgrams, batch_idx, optimizer_idx):
    batch_size = sgrams.shape[0]
    assert sgrams.shape[1:] == self.spectrogram_shape  # (N, 1, H, W)
    """
    Four classes with product of terms objective.
    See Dandi et al., "Generalized Adversarially Learned Inference"
    https://arxiv.org/pdf/2006.08089.pdf
    0: x, E(x)
    1: z, G(z)
    2: x, E(G(E(x)))
    3: G(E(G(z))), z
    """
    log_probs = [None] * 4

    # Class 0:
    true_latent_params = self.encoder(sgrams).view(-1, 2, self.hparams.latent_size)
    true_mu = true_latent_params[:, 0, :]
    true_std = true_latent_params[:, 1, :]
    true_latent_sample = self.sample(true_mu, true_std)
    logits = self.discriminator((sgrams, true_latent_sample))
    log_probs[0] = F.log_softmax(logits, dim=1)
    if optimizer_idx == 0:
        pt_loss = self.product_of_terms_loss(log_probs[0], 0)

    # Class 1:
    fake_latent = torch.randn(batch_size, self.hparams.latent_size, device=self.device)
    fake_sgrams = self.generator(fake_latent)
    logits = self.discriminator((fake_sgrams, fake_latent))
    log_probs[1] = F.log_softmax(logits, dim=1)
    if optimizer_idx == 0:
        pt_loss += self.product_of_terms_loss(log_probs[1], 1)

    # Class 2:
    true_reencoded_params = self.encoder(self.generator(true_latent_sample)).view(-1, 2, self.hparams.latent_size)
    true_reencoded_mu = true_reencoded_params[:, 0, :]
    true_reencoded_std = true_reencoded_params[:, 1, :]
    true_reencoded_latent_sample = self.sample(true_reencoded_mu, true_reencoded_std)
    logits = self.discriminator((sgrams, true_reencoded_latent_sample))
    log_probs[2] = F.log_softmax(logits, dim=1)
    if optimizer_idx == 0:
        pt_loss += self.product_of_terms_loss(log_probs[2], 2)

    # Class 3:
    fake_reencoded_params = self.encoder(fake_sgrams).view(-1, 2, self.hparams.latent_size)
    fake_reencoded_mu = fake_reencoded_params[:, 0, :]
    fake_reencoded_std = fake_reencoded_params[:, 1, :]
    fake_reencoded_latent_sample = self.sample(fake_reencoded_mu, fake_reencoded_std)
    fake_regenerated_sgrams = self.generator(fake_reencoded_latent_sample)
    self.log("reconstruction-mse", mean_squared_error(fake_regenerated_sgrams, sgrams))
    self.log("reconstruction-ssim", ssim(fake_regenerated_sgrams, sgrams))
    logits = self.discriminator((fake_regenerated_sgrams, fake_latent))
    log_probs[3] = F.log_softmax(logits, dim=1)
    # Encoder-Generator loss
    if optimizer_idx == 0:
        pt_loss += self.product_of_terms_loss(log_probs[3], 3)

    last_update = "D" if optimizer_idx == 0 else "EG"
    step_type = "t" if self.training else "v"
    self.log(f"x,E(x)|{step_type}|{last_update}", torch.exp(log_probs[0][:, 0]).mean())
    self.log(f"z,G(z)|{step_type}|{last_update}", torch.exp(log_probs[1][:, 1]).mean())
    self.log(f"x,E(G(E(x)))|{step_type}|{last_update}", torch.exp(log_probs[2][:, 2]).mean())
    self.log(f"G(E(G(z))),z|{step_type}|{last_update}", torch.exp(log_probs[3][:, 3]).mean())

    if optimizer_idx == 0:
        # Weight the objective containing m=4 log terms by 2/m=0.5, corresponding to a weight
        # of 1 for ALI's generator and discriminator objective, keeping the objectives
        # in a similar range for both optimizers
        pt_loss = pt_loss.mean() / 2
        self.log("eg_loss", pt_loss)
        return -pt_loss  # maximize the product-of-terms objective

    # Discriminator loss
    if optimizer_idx == 1:
        d_loss = torch.zeros((batch_size,), dtype=log_probs[0].dtype, device=self.device)
        for true_idx in range(4):
            target = torch.full((batch_size,), true_idx, dtype=torch.long, device=self.device)
            # Maybe randomly swap labels instead
            d_loss += self.smoothing_cross_entropy_loss(log_probs[true_idx], target)
            # d_loss += F.cross_entropy(logits[true_idx], target, reduction="none")
        d_loss = d_loss.mean()  # Reduce across batch dimension
        self.log("d_loss", d_loss)
        return d_loss
def forward(self, preds, target):
    preds = self.unnormalize(preds)
    target = self.unnormalize(target)
    return ssim(preds, target, data_range=self.scale)
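# A plausible (assumed) home for the forward above: a metric module that
# undoes dataset normalization before scoring, so data_range matches the
# restored value range. The mean/std/scale values below are illustrative,
# not taken from the original project.
class UnnormalizedSSIM(torch.nn.Module):
    def __init__(self, mean=0.5, std=0.5, scale=1.0):
        super().__init__()
        self.mean, self.std, self.scale = mean, std, scale

    def unnormalize(self, x):
        # invert (x - mean) / std so SSIM sees the original value range
        return x * self.std + self.mean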