Exemple #1
0
    def _test(metric_device):
        """Compare the distributed ``SSIM`` metric with scikit-image.

        Every rank reads its own slice (starting at ``offset * rank``, in chunks
        of ``s``) of a shared random batch through an ``Engine``. The metric
        reported by the engine is compared with scikit-image's SSIM on the full
        batch, first for the metric's default window (reference: Gaussian,
        ``win_size=11``) and then for a 7x7 uniform window.
        """
        y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28,
                            dtype=torch.float, device=device)
        y = y_pred * 0.65

        def update(engine, i):
            # This rank's window for iteration ``i``.
            start = offset * rank + i * s
            stop = offset * rank + (i + 1) * s
            return y_pred[start:stop], y[start:stop]

        # Pairs of (extra ignite SSIM kwargs, matching scikit-image kwargs).
        configurations = (
            ({}, {"win_size": 11, "gaussian_weights": True}),
            ({"gaussian": False, "kernel_size": 7},
             {"win_size": 7, "gaussian_weights": False}),
        )

        for metric_kwargs, reference_kwargs in configurations:
            engine = Engine(update)
            SSIM(data_range=1.0, device=metric_device,
                 **metric_kwargs).attach(engine, "ssim")
            engine.run(data=list(range(n_iters)), max_epochs=1)

            assert "ssim" in engine.state.metrics
            res = engine.state.metrics["ssim"]

            # scikit-image expects channels-last arrays.
            np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
            np_true = np_pred * 0.65
            true_res = ski_ssim(np_pred, np_true, multichannel=True,
                                data_range=1.0, **reference_kwargs)

            assert pytest.approx(res, abs=tol) == true_res
Exemple #2
0
def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance):
    """Compare ignite's ``SSIM`` with scikit-image on ``y = 0.8 * y_pred``.

    ``kernel_size`` and ``gaussian`` are forwarded to both implementations so
    that each parametrization exercises the same window on both sides.
    ``use_sample_covariance`` exists only on the scikit-image side.
    """
    y_pred = torch.rand(shape, device=device)
    y = y_pred * 0.8

    sigma = 1.5
    data_range = 1.0
    # Forward the window settings; otherwise ignite silently uses its default
    # kernel while the reference uses the parametrized one.
    ssim = SSIM(
        data_range=data_range,
        kernel_size=kernel_size,
        sigma=sigma,
        gaussian=gaussian,
        device=device,
    )
    ssim.update((y_pred, y))
    ignite_ssim = ssim.compute()

    skimg_pred = y_pred.cpu().numpy()
    # Use exactly the target ignite saw instead of re-deriving it.
    skimg_y = y.cpu().numpy()
    # NOTE(review): with channel_axis=1 on an (N, C, H, W) array, scikit-image
    # treats (N, H, W) as spatial axes, so its window also spans neighbouring
    # images in the batch. Confirm this is the intended reference.
    skimg_ssim = ski_ssim(
        skimg_pred,
        skimg_y,
        win_size=kernel_size,
        sigma=sigma,
        channel_axis=1,
        gaussian_weights=gaussian,
        data_range=data_range,
        use_sample_covariance=use_sample_covariance,
    )

    assert isinstance(ignite_ssim, float)
    assert np.allclose(ignite_ssim, skimg_ssim, atol=7e-5)
Exemple #3
0
def test_ssim():
    """Compare ignite's ``SSIM`` with scikit-image for two window settings.

    1. ignite's default window vs. scikit-image's Gaussian ``win_size=11`` on
       16 x 3 x 64 x 64 inputs.
    2. A 7x7 uniform window on both sides on 16 x 3 x 227 x 227 inputs.

    In both cases ``y = 0.65 * y_pred``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _check(metric_kwargs, ski_kwargs, size):
        """Run one SSIM configuration and assert it matches scikit-image."""
        ssim = SSIM(data_range=1.0, device=device, **metric_kwargs)
        y_pred = torch.rand(16, 3, size, size, device=device)
        y = y_pred * 0.65
        ssim.update((y_pred, y))
        # Compute once and reuse it for both assertions.
        result = ssim.compute()

        # scikit-image expects channels-last arrays.
        # NOTE(review): with multichannel=True the (N, H, W) axes are all
        # spatial to scikit-image, so its window spans the batch axis too.
        np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
        np_y = np_pred * 0.65
        np_ssim = ski_ssim(np_pred, np_y, multichannel=True, data_range=1.0,
                           **ski_kwargs)

        assert isinstance(result, torch.Tensor)
        expected = torch.tensor(np_ssim, dtype=torch.float64, device=device)
        assert torch.allclose(result, expected, atol=1e-4)

    _check({}, {"win_size": 11, "gaussian_weights": True}, size=64)
    _check({"gaussian": False, "kernel_size": 7},
           {"win_size": 7, "gaussian_weights": False}, size=227)
Exemple #4
0
def test_ssim(size, channel, plus, multichannel):
    """Compare ``ssim(pred, pred + plus)`` with scikit-image's SSIM.

    The scikit-image reference is computed on the same pixels as ``pred``
    (moved to channels-last), not on an independently drawn random image, so
    both sides measure the same data. Also checks ``ssim(pred, pred)`` is 1.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)

    # Same data as ``pred``: drop the singleton batch dim, move channels last.
    np_pred = pred[0].permute(1, 2, 0).cpu().numpy()
    if not multichannel:
        # NOTE(review): ``ssim`` above still sees every channel while the
        # reference uses only channel 0; confirm intended when channel > 1.
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
Exemple #5
0
def test_ssim(size, channel, coef, multichannel):
    """Check ``ssim(pred, pred * coef)`` against scikit-image, plus identity.

    NOTE(review): the batch dimension is drawn as ``size`` (same as height and
    width), and scikit-image receives the whole batch array, so its window
    also spans the batch axis. Confirm this reference is intended.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(size, channel, size, size, device=device)
    ssim_idx = ssim(pred, pred * coef, data_range=1.0)

    # Channels-last numpy copy for scikit-image; single-channel mode keeps
    # only the first channel.
    np_pred = pred.permute(0, 2, 3, 1).cpu().numpy()
    if multichannel is False:
        np_pred = np_pred[..., 0]

    reference = ski_ssim(
        np_pred,
        np_pred * coef,
        win_size=11,
        multichannel=multichannel,
        gaussian_weights=True,
        data_range=1.0,
    )
    expected = torch.tensor(reference, dtype=torch.float, device=device)
    assert torch.allclose(ssim_idx, expected, atol=1e-4)

    # An image compared with itself should score 1.
    self_similarity = ssim(pred, pred)
    assert torch.allclose(self_similarity, torch.tensor(1.0, device=device))
Exemple #6
0
def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):
    """Assert ignite's ``SSIM`` on ``(y_pred, y)`` matches scikit-image.

    Args:
        y_pred: prediction tensor; axis 1 is passed to scikit-image as the
            channel axis.
        y: target tensor with the same shape as ``y_pred``. It is used as-is
            on both sides (no assumption about how it relates to ``y_pred``).
        data_range, kernel_size, sigma, gaussian: forwarded to both ignite's
            ``SSIM`` and scikit-image.
        use_sample_covariance: scikit-image only.
        device: device for the ignite metric; the result is asserted to be a
            float64 tensor on this device.
    """
    atol = 7e-5
    # Forward the window settings so both implementations use the same kernel.
    ssim = SSIM(
        data_range=data_range,
        kernel_size=kernel_size,
        sigma=sigma,
        gaussian=gaussian,
        device=device,
    )
    ssim.update((y_pred, y))
    ignite_ssim = ssim.compute()

    skimg_pred = y_pred.cpu().numpy()
    # Use the target ignite actually received rather than assuming 0.8 * y_pred.
    skimg_y = y.cpu().numpy()
    skimg_ssim = ski_ssim(
        skimg_pred,
        skimg_y,
        win_size=kernel_size,
        sigma=sigma,
        channel_axis=1,
        gaussian_weights=gaussian,
        data_range=data_range,
        use_sample_covariance=use_sample_covariance,
    )

    assert isinstance(ignite_ssim, torch.Tensor)
    assert ignite_ssim.dtype == torch.float64
    assert ignite_ssim.device == torch.device(device)
    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=atol)
Exemple #7
0
def _clip_to_range(frames, is_sunspots):
    """Clip predicted frames to the dataset's value range.

    Sunspots frames are clipped to [0, 1]; every other dataset to [-1, 1].
    """
    a_min = 0 if is_sunspots else -1.
    return np.clip(frames, a_min=a_min, a_max=1.0)


def _to_uint8(frames, is_sunspots):
    """Convert frames to uint8 pixel values (truncating, not rounding).

    Sunspots frames are scaled by 255 ([0, 1] -> [0, 255]); other datasets by
    ``x * 127.5 + 127.5`` ([-1, 1] -> [0, 255]).
    """
    if is_sunspots:
        return np.uint8(frames * 255)
    return np.uint8(frames * 127.5 + 127.5)


def _save_examples(res_path, names, img_seq, img_gen, configs, is_sunspots):
    """Save the first sequence of a batch as an ``.npz`` archive and PNG frames.

    Ground-truth frames go to ``gt<k>.png`` (k from 1) and predicted frames to
    ``pd<k>.png`` (k from 2), inside a directory named ``names[0][0]``.
    """
    path = os.path.join(res_path, names[0][0])
    os.makedirs(path, exist_ok=True)
    name_list = [n[0] for n in names]
    np.savez(os.path.join(res_path, '%s.npz' % names[0][0]),
             inputs=img_seq[0],
             preds=img_gen[0],
             names=name_list)
    for i in range(configs.input_length + configs.predict_length):
        file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
        cv2.imwrite(file_name, _to_uint8(img_seq[0, i], is_sunspots))
    for i in range(img_gen.shape[1]):
        file_name = os.path.join(path, 'pd' + str(i + 2) + '.png')
        img_pd = _clip_to_range(img_gen[0, i], is_sunspots)
        cv2.imwrite(file_name, _to_uint8(img_pd, is_sunspots))


def test(model, valid_loader, configs, epoch, folder, is_sunspots):
    """Evaluate ``model`` on ``valid_loader``, save examples and log metrics.

    For each of the ``configs.predict_length`` predicted frames, computes the
    per-pixel MSE and the SSIM (on uint8 frames), each averaged over all
    evaluated samples. Prints them and appends the CSV row
    ``[epoch, avg_mse, *mse_per_frame, *ssim_per_frame]`` to
    ``<configs.save_dir>/Metric.txt``. Example predictions are written under
    ``<folder>/<epoch>``.

    Returns:
        dict: ``{'avg_mse': ..., 'avg_ssim': ...}``, where ``avg_mse`` is the
        sum of the per-frame MSEs and ``avg_ssim`` the mean per-frame SSIM.

    Raises:
        ValueError: if ``valid_loader`` yields no samples.
    """
    print('====>  ', 'test for epoch  %s' % str(epoch),
          datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    res_path = os.path.join(folder, str(epoch))
    os.makedirs(res_path, exist_ok=True)

    output_length = configs.predict_length
    img_mse = [0] * output_length
    ssim = [0] * output_length
    # Counted from the data itself so a short final batch is handled correctly.
    n_samples = 0

    with torch.no_grad():
        real_input_flag = generator_input_flag(configs)
        for idx, (test_dat, names) in enumerate(valid_loader):
            img_gen = model.test(test_dat, real_input_flag)
            img_gen = img_gen.transpose(0, 1, 3, 4, 2)  # output sequence [N S H W C]
            img_seq = test_dat.detach().cpu().numpy().transpose(
                0, 1, 3, 4, 2)  # full input sequence [N S H W C]
            img_out = img_gen[:, -output_length:]
            batch_size = img_seq.shape[0]
            n_samples += batch_size

            for i in range(output_length):
                x = img_seq[:, i + configs.input_length]
                # Clip the model output to the valid value range.
                gx = _clip_to_range(img_out[:, i], is_sunspots)
                # Sum of per-sample per-pixel MSEs; dividing by n_samples
                # below yields the mean per-pixel MSE for this frame.
                img_mse[i] += np.square(x - gx).mean(axis=(1, 2, 3)).sum()

                real_frm = _to_uint8(x, is_sunspots)
                pred_frm = _to_uint8(gx, is_sunspots)
                for b in range(batch_size):
                    # NOTE(review): `multichannel` is deprecated in newer
                    # scikit-image releases in favour of `channel_axis`.
                    score, _ = ski_ssim(pred_frm[b],
                                        real_frm[b],
                                        full=True,
                                        multichannel=True)
                    ssim[i] += score

            # Save prediction examples.
            # NOTE(review): `<=` saves num_save_samples + 1 batches; confirm.
            if idx <= configs.num_save_samples or (not configs.is_training):
                _save_examples(res_path, names, img_seq, img_gen, configs,
                               is_sunspots)

    if n_samples == 0:
        raise ValueError('valid_loader yielded no samples')

    img_mse = [m / n_samples for m in img_mse]
    avg_mse = sum(img_mse)
    print('mse of avg: ' + str(avg_mse))
    print('mse of seq: ' + str(img_mse))

    ssim = [s / n_samples for s in ssim]
    avg_ssim = np.mean(ssim)
    print('ssim of avg: ' + str(avg_ssim))
    print('ssim of seq: ' + str(ssim))

    # Context manager guarantees the row is flushed and the handle closed.
    with open(os.path.join(configs.save_dir, 'Metric.txt'), 'a') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow([epoch, avg_mse] + img_mse + ssim)

    return {'avg_mse': avg_mse, 'avg_ssim': avg_ssim}
Exemple #8
0
                                              (epoch, ti)),
                                          padding=0,
                                          normalize=True)
                    # Calculate SSIM and PSNR
                    fakeB = fakeB.data.cpu().numpy()
                    realB = realB.data.cpu().numpy()
                    fakeB = fakeB.transpose(
                        (0, 2, 3, 1))  # N C H W  ---> N H W C
                    realB = realB.transpose(
                        (0, 2, 3, 1))  # N C H W  ---> N H W C
                    fakeB = np.clip(fakeB, a_min=0.0, a_max=1.0)
                    realB = np.clip(realB, a_min=0.0, a_max=1.0)

                    for _bti in range(fakeB.shape[0]):
                        per_fakeB = fakeB[_bti]
                        per_realB = realB[_bti]
                        test_ssim += ski_ssim(per_fakeB,
                                              per_realB,
                                              data_range=1,
                                              multichannel=True)
                        test_psnr += ski_psnr(per_realB,
                                              per_fakeB,
                                              data_range=1)
                        test_ite += 1

                test_psnr /= test_ite
                test_ssim /= test_ite
                print('     Valid PSNR: {:.4f}'.format(test_psnr))
                print('     Valid SSIM: {:.4f}'.format(test_ssim))
                print('------------------------')