Example #1
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    # Mismatched dtypes (float32 vs. float64) should raise a TypeError.
    pred_t = torch.rand(pred)
    target_t = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(pred_t, target_t)

    # Invalid kernel sizes or sigma values should raise a ValueError.
    pred = torch.rand(pred)
    target = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(pred, target, kernel, sigma)
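
The parametrization this test expects is not shown. Below is a minimal self-contained sketch of the same two failure modes, assuming the pre-v1.5 pytorch_lightning.metrics.functional.ssim signature (kernel_size, sigma); the shapes and invalid values are illustrative, not the originals:

import pytest
import torch

# Assumed import path: the pre-v1.5 pytorch_lightning functional metric.
from pytorch_lightning.metrics.functional import ssim


def test_ssim_invalid_inputs_sketch():
    # Mismatched dtypes (float32 vs. float64) should raise a TypeError.
    pred = torch.rand(1, 1, 16, 16)
    target = torch.rand(1, 1, 16, 16, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(pred, target)

    # An even kernel size is invalid and should raise a ValueError.
    pred = torch.rand(1, 1, 16, 16)
    target = torch.rand(1, 1, 16, 16)
    with pytest.raises(ValueError):
        ssim(pred, target, kernel_size=(10, 10), sigma=(1.5, 1.5))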
Example #2
def test_ssim(size, channel, plus, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)
    # Build an independent random pair in NumPy and use scikit-image's SSIM as
    # the reference; the assertion below is therefore statistical (atol/rtol 1e-2).
    np_pred = np.random.rand(size, size, channel)
    if not multichannel:
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    # Identical inputs must score an SSIM of exactly 1.
    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
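
Examples #2 and #3 compare against scikit-image. The import preamble they appear to assume (module paths are a best guess; ski_ssim is evidently an alias for skimage's structural_similarity):

import numpy as np
import pytest
import torch

# skimage >= 0.16 exposes SSIM under skimage.metrics.
from skimage.metrics import structural_similarity as ski_ssim

# Assumed source of `ssim`: the pre-v1.5 pytorch_lightning functional metric.
from pytorch_lightning.metrics.functional import ssim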
Example #3
def test_ssim(size, channel, coef, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(size, channel, size, size, device=device)
    target = pred * coef
    ssim_idx = ssim(pred, target, data_range=1.0)
    # Convert to channels-last NumPy for scikit-image; both libraries see the same data here.
    np_pred = pred.permute(0, 2, 3, 1).cpu().numpy()
    if not multichannel:
        np_pred = np_pred[:, :, :, 0]
    np_target = np.multiply(np_pred, coef)
    sk_ssim_idx = ski_ssim(
        np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0
    )
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
Example #4
    def _test_step(self, batch):
        img, target = batch
        pred, kl = self(x=img)
        pred_imgs = [pred]

        # Iteratively refine the prediction by feeding it back in as y.
        for _ in range(self.refine_steps):
            pred, kl = self(x=img, y=pred)
            pred_imgs.append(pred)

        mse_loss = ((pred - target)**2).mean(dim=(1, 2, 3))
        loss = mse_loss + (self.hparams.kl_coeff * kl)
        loss = loss.mean()

        ssim_val = ssim(pred, target)

        logs = {
            "mse_loss": mse_loss.mean(),
            "kl": kl.mean(),
            "scaled_kl": self.hparams.kl_coeff * kl.mean(),
            "loss": loss,
            "ssim": ssim_val,
        }

        img_logs = {
            "input": img[0].squeeze(0),
            "pred_imgs": [pred[0] for pred in pred_imgs],
            "target": target[0]
        }

        return loss, logs, img_logs
Example #5
    def _train_step(self, batch):
        img, target = batch
        pred, kl = self(x=img, y=target)

        mse_loss = ((pred - target)**2).mean(dim=(1, 2, 3))
        loss = mse_loss + (self.hparams.kl_coeff * kl)
        loss = loss.mean()

        ssim_val = ssim(pred, target)

        logs = {
            "mse_loss": mse_loss.mean(),
            "kl": kl.mean(),
            "scaled_kl": self.hparams.kl_coeff * kl.mean(),
            "loss": loss,
            "ssim": ssim_val,
        }

        img_logs = {
            "input": img[0].squeeze(0),
            "pred": pred[0],
            "target": target[0]
        }

        return loss, logs, img_logs
Example #6
def ssim_lightning(real, fake, return_per_frame=False, normalize_range=True):
    # Promote lower-rank inputs to a batched 5D (B, T, C, H, W) layout.
    if real.dim() == 3:
        real = real[None, None]
        fake = fake[None, None]
    elif real.dim() == 4:
        real = real[None]
        fake = fake[None]

    # Map inputs from [-1, 1] to [0, 1] before computing SSIM.
    if normalize_range:
        real = (real + 1.) / 2.
        fake = (fake + 1.) / 2.

    # Fold batch and time together and score all frames at once.
    ssim_batch = PF.ssim(fake.reshape(-1, *fake.shape[2:]),
                         real.reshape(-1, *real.shape[2:])).cpu().numpy()

    if return_per_frame:
        ssim_per_frame = {
            i: PF.ssim(fake[:, i], real[:, i]).cpu().numpy()
            for i in range(real.shape[1])
        }

        return ssim_batch, ssim_per_frame

    return ssim_batch
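
A usage sketch for ssim_lightning, assuming PF aliases pytorch_lightning.metrics.functional (as the calls above suggest) and inputs scaled to [-1, 1]; the shapes are hypothetical:

import torch
import pytorch_lightning.metrics.functional as PF  # assumed alias used by ssim_lightning

# Hypothetical clip pair in [-1, 1]: batch of 2 clips, 8 frames, 3x64x64 each.
real = torch.rand(2, 8, 3, 64, 64) * 2 - 1
fake = (real + 0.05 * torch.randn_like(real)).clamp(-1., 1.)

score = ssim_lightning(real, fake)  # one score over all frames
score, per_frame = ssim_lightning(real, fake, return_per_frame=True)  # plus a score per frame index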
Example #7
def iqa_metrics(data_loader, transform=None):
    metrics = {}
    if 'SVHN' not in args.dataset:
        Id_mat = torch.eye(n_dims, device=DEVICE).reshape(-1, *input_shape)
        metrics['l2-err'] = ((transform(Id_mat) - Id_mat).norm() /
                             Id_mat.norm()).item()

    if args.dataset in ('CIFAR10', 'MNIST'):
        metrics['PSNR'] = 0
        metrics['SSIM'] = 0
        # metrics['HaarPsi'] = 0

        for inputs, labels in data_loader:
            images = inputs
            restored = transform(images)

            images = utility.to_zero_one(images)
            restored = utility.to_zero_one(restored)

            # for image, restored_image in zip(
            #         images.permute(0, 2, 3, 1).squeeze().cpu().numpy(),
            #         restored.permute(0, 2, 3, 1).squeeze().cpu().numpy()):
            #     metrics['HaarPsi'] += (haar_psi_numpy(image, restored_image)[0]
            #                            / len(data_loader) / len(inputs))

            if images.shape[1] == 3:
                # Compare on luminance only, as is conventional for SSIM/PSNR on RGB.
                images = utility.rbg_to_luminance(images)
                restored = utility.rbg_to_luminance(restored)

            metrics['PSNR'] += utility.psnr(
                images, restored).mean().item() / len(data_loader)
            metrics['SSIM'] += (1 + ssim(images, restored).item()) / 2 / len(
                data_loader)  # default: mean, scale to [0, 1]

    if args.dataset == 'GMM':
        metrics['c-entropy'] = 0
        for inputs, labels in data_loader:
            metrics['c-entropy'] += dataset.cross_entropy(
                transform(inputs)).item() / len(data_loader)

    return metrics
Example #8
    def step(self, batch):
        img, target = batch
        img = img.squeeze(1)
        pred = self(img)

        mse_loss = ((pred - target)**2).mean(dim=(1, 2, 3)).mean()

        ssim_val = ssim(pred, target)

        logs = {
            "mse_loss": mse_loss.mean(),
            "ssim": ssim_val,
        }

        img_logs = {
            "input": img[0].squeeze(0),
            "pred": pred[0],
            "target": target[0]
        }

        return mse_loss, logs, img_logs
Example #9
    def forward(self, pred, target):
        pred, target = map(self.to_slices, (pred, target))
        return ssim(pred, target, *self.ssim_args, **self.ssim_kwargs)
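
The to_slices helper in Example #9 is not shown. Purely as a guess at its intent, one way such a helper might fold volumetric inputs into 2D slices so a 2D SSIM applies (entirely hypothetical; the real helper may differ):

    def to_slices(self, x):
        # Hypothetical: fold the depth axis into the batch axis, turning
        # (N, C, D, H, W) volumes into (N * D, C, H, W) slice stacks.
        n, c, d, h, w = x.shape
        return x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)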
Example #10
def test_v1_5_metric_regress():
    # Reset the one-shot deprecation-warning cache so each constructor warns again.
    ExplainedVariance.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        ExplainedVariance()

    MeanAbsoluteError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanAbsoluteError()

    MeanSquaredError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanSquaredError()

    MeanSquaredLogError.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        MeanSquaredLogError()

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    explained_variance._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = explained_variance(preds, target)
    assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4)

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    mean_absolute_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_absolute_error(x, y) == 0.25

    mean_relative_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_relative_error(x, y) == 0.125

    mean_squared_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert mean_squared_error(x, y) == 0.25

    mean_squared_log_error._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = mean_squared_log_error(x, y)
    assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4)

    PSNR.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        PSNR()

    R2Score.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        R2Score()

    SSIM.__init__._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        SSIM()

    preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    psnr._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = psnr(preds, target)
    assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4)

    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    r2score._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = r2score(preds, target)
    assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4)

    preds = torch.rand([16, 1, 16, 16])
    target = preds * 0.75
    ssim._warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        res = ssim(preds, target)
    assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4)
Example #11
File: gan.py  Project: reppertj/earworm
    def training_step(self, sgrams, batch_idx, optimizer_idx):
        """Four classes with a product-of-terms objective.

        See Dandi et al., "Generalized Adversarially Learned Inference",
        https://arxiv.org/pdf/2006.08089.pdf
        0: x, E(x)
        1: z, G(z)
        2: x, E(G(E(x)))
        3: G(E(G(z))), z
        """
        batch_size = sgrams.shape[0]
        assert sgrams.shape[1:] == self.spectrogram_shape  # (N, 1, H, W)
        log_probs = [None] * 4
        # Class 0:
        true_latent_params = self.encoder(sgrams).view(
            -1, 2, self.hparams.latent_size)
        true_mu = true_latent_params[:, 0, :]
        true_std = true_latent_params[:, 1, :]
        true_latent_sample = self.sample(true_mu, true_std)
        logits = self.discriminator((sgrams, true_latent_sample))
        log_probs[0] = F.log_softmax(logits, dim=1)
        if optimizer_idx == 0:
            pt_loss = self.product_of_terms_loss(log_probs[0], 0)

        # Class 1:
        fake_latent = torch.randn(batch_size,
                                  self.hparams.latent_size,
                                  device=self.device)
        fake_sgrams = self.generator(fake_latent)
        logits = self.discriminator((fake_sgrams, fake_latent))
        log_probs[1] = F.log_softmax(logits, dim=1)
        if optimizer_idx == 0:
            pt_loss += self.product_of_terms_loss(log_probs[1], 1)

        # Class 2:
        true_reencoded_params = self.encoder(
            self.generator(true_latent_sample)).view(-1, 2,
                                                     self.hparams.latent_size)
        true_reencoded_mu = true_reencoded_params[:, 0, :]
        true_reencoded_std = true_reencoded_params[:, 1, :]
        true_reencoded_latent_sample = self.sample(true_reencoded_mu,
                                                   true_reencoded_std)
        logits = self.discriminator((sgrams, true_reencoded_latent_sample))
        log_probs[2] = F.log_softmax(logits, dim=1)
        if optimizer_idx == 0:
            pt_loss += self.product_of_terms_loss(log_probs[2], 2)

        # Class 3:
        fake_reencoded_params = self.encoder(fake_sgrams).view(
            -1, 2, self.hparams.latent_size)
        fake_reencoded_mu = fake_reencoded_params[:, 0, :]
        fake_reencoded_std = fake_reencoded_params[:, 1, :]
        fake_reencoded_latent_sample = self.sample(fake_reencoded_mu,
                                                   fake_reencoded_std)
        fake_regenerated_sgrams = self.generator(fake_reencoded_latent_sample)
        self.log("reconstruction-mse",
                 mean_squared_error(fake_regenerated_sgrams, sgrams))
        self.log("reconstruction-ssim", ssim(fake_regenerated_sgrams, sgrams))
        logits = self.discriminator((fake_regenerated_sgrams, fake_latent))
        log_probs[3] = F.log_softmax(logits, dim=1)
        # Encoder-Generator loss

        if optimizer_idx == 0:
            pt_loss += self.product_of_terms_loss(log_probs[3], 3)

        last_update = "D" if optimizer_idx == 0 else "EG"
        step_type = "t" if self.training else "v"
        self.log(f"x,E(x)|{step_type}|{last_update}",
                 torch.exp(log_probs[0][:, 0]).mean())
        self.log(f"z,G(z)|{step_type}|{last_update}",
                 torch.exp(log_probs[1][:, 1]).mean())
        self.log(f"x,E(G(E(x)))|{step_type}|{last_update}",
                 torch.exp(log_probs[2][:, 2]).mean())
        self.log(f"G(E(G(z))),z|{step_type}|{last_update}",
                 torch.exp(log_probs[3][:, 3]).mean())

        if optimizer_idx == 0:
            # Weight the objective containing m=4 log terms by 2/m=0.5, corresponding to a weight
            # of 1 for ALI's generator and discriminator objective, keeping the objectives
            # in a similar range for both optimizers
            pt_loss = pt_loss.mean() / 2
            self.log("eg_loss", pt_loss)
            return -pt_loss

        # Discriminator loss

        if optimizer_idx == 1:
            d_loss = torch.zeros((batch_size, ),
                                 dtype=log_probs[0].dtype,
                                 device=self.device)
            for true_idx in range(4):
                target = torch.full((batch_size, ),
                                    true_idx,
                                    dtype=torch.long,
                                    device=self.device)
                # Maybe randomly swap labels instead
                d_loss += self.smoothing_cross_entropy_loss(
                    log_probs[true_idx], target)
                # d_loss += F.cross_entropy(logits[true_idx], target, reduction="none")

            d_loss = d_loss.mean()
            self.log("d_loss", d_loss)
            return d_loss  # already reduced across the batch dimension above
Example #12
    def forward(self, preds, target):
        # Undo normalization so SSIM sees images in their original data range.
        preds = self.unnormalize(preds)
        target = self.unnormalize(target)
        return ssim(preds, target, data_range=self.scale)
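
Example #12's unnormalize and scale are defined elsewhere. A minimal sketch of a wrapper consistent with that forward(), assuming a standard mean/std normalization is being undone; every name beyond those in the snippet is hypothetical:

import torch
import torch.nn as nn

# Assumed source of `ssim`: the pre-v1.5 pytorch_lightning functional metric.
from pytorch_lightning.metrics.functional import ssim


class NormalizedSSIM(nn.Module):
    # Hypothetical wrapper matching Example #12's forward().
    def __init__(self, mean, std, scale=1.0):
        super().__init__()
        # Shape (1, C, 1, 1) so the buffers broadcast over (N, C, H, W) batches.
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))
        self.scale = scale  # data range of the unnormalized images

    def unnormalize(self, x):
        return x * self.std + self.mean

    def forward(self, preds, target):
        preds = self.unnormalize(preds)
        target = self.unnormalize(target)
        return ssim(preds, target, data_range=self.scale)


# Usage (hypothetical ImageNet-style statistics):
# metric = NormalizedSSIM(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])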