Example #1
def msssim_l1_loss(pred, tar):
    return 0.7 * msssimloss(
        rotate_real(pred)[:, 0:1, ...],
        rotate_real(tar)[:, 0:1, ...],
    ) + 0.3 * maeloss(
        rotate_real(pred)[:, 0:1, ...],
        rotate_real(tar)[:, 0:1, ...],
    )
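Every snippet on this page extracts a single magnitude channel via rotate_real(...)[:, 0:1, ...] from tensors that store complex images as two real channels. The repository's rotate_real is not shown here, so the stand-in below is only a hedged sketch of that idiom, assuming the function rotates the complex signal so that its magnitude ends up in channel 0.

import torch

def rotate_real_sketch(x):
    # Hypothetical stand-in (not the repository's implementation): take an
    # (N, 2, H, W) tensor holding real/imaginary parts and return a tensor of
    # the same shape with the complex magnitude in channel 0 and zeros in
    # channel 1, so that [:, 0:1, ...] selects the magnitude image.
    mag = torch.sqrt(x[:, 0:1, ...] ** 2 + x[:, 1:2, ...] ** 2)
    return torch.cat([mag, torch.zeros_like(mag)], dim=1)

pred = torch.randn(4, 2, 320, 320)
mag = rotate_real_sketch(pred)[:, 0:1, ...]
print(mag.shape)  # torch.Size([4, 1, 320, 320])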
Example #2
def evaluate_dataset(self, data, max_vols=None):
    """ Evaluates performance measures per volume. """
    logging = pd.DataFrame(
        index=range(data.volumes),
        columns=["nmse", "psnr", "ssim"],
    )
    num = (min(data.volumes, max_vols)
           if max_vols is not None else data.volumes)
    t = tqdm(range(num))
    for k in t:
        lo, hi = data.get_slices_in_volume(k)
        pred_list, tar_list = [], []
        for sl in range(lo, hi):
            slice_data = data[sl]
            if len(slice_data) == 3:
                inp, aux, tar = slice_data
                inp = inp.to(self.device).unsqueeze(0)
                aux = aux.to(self.device).unsqueeze(0)
                tar = tar.to(self.device).unsqueeze(0)
                pred = self.forward((inp, aux))
            else:
                inp, tar = slice_data
                inp = inp.to(self.device).unsqueeze(0)
                tar = tar.to(self.device).unsqueeze(0)
                pred = self.forward(inp)
            # make complex signals real if necessary
            if tar.shape[-3] == 2:
                tar = rotate_real(tar)[..., 0:1, :, :]
            if pred.shape[-3] == 2:
                pred = rotate_real(pred)[..., 0:1, :, :]
            tar_list.append(tar.detach().cpu())
            pred_list.append(pred.detach().cpu())
        tar_np = torch.cat(tar_list, dim=0).squeeze(1).numpy()
        pred_np = torch.cat(pred_list, dim=0).squeeze(1).numpy()
        logging.iloc[k] = {
            "nmse": evaluate.nmse(tar_np, pred_np),
            "psnr": evaluate.psnr(tar_np, pred_np),
            "ssim": evaluate.ssim(tar_np, pred_np),
        }
    print(
        pd.DataFrame({
            "min": logging.min(),
            "mean": logging.mean(),
            "max": logging.max(),
        }))
    return logging
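The per-volume metrics above delegate to an evaluate module that is not part of the snippet. As a rough, self-contained approximation of what such volume metrics typically compute (modeled on the common fastMRI-style definitions, so the exact formulas are an assumption):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def nmse(gt, pred):
    # normalized mean squared error over the whole (slices, H, W) volume
    return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2

def psnr(gt, pred):
    # peak signal-to-noise ratio with the volume maximum as data range
    return peak_signal_noise_ratio(gt, pred, data_range=gt.max())

def ssim(gt, pred):
    # mean structural similarity over the slices of the volume
    return np.mean([
        structural_similarity(gt[i], pred[i], data_range=gt.max())
        for i in range(gt.shape[0])
    ])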
Example #3
def __call__(self, imgs):
    return tuple(
        [
            rotate_real(img)[..., 0:1, :, :]
            if img.shape[-3] == 2
            else torch.abs(img)
            for img in imgs
        ]
    )
Example #4
def grid_search(x, y, rec_func, grid):
    """ Grid search utility for tuning hyper-parameters. """
    err_min = np.inf
    grid_param = None

    grid_shape = [len(val) for val in grid.values()]
    err = torch.zeros(grid_shape)
    err_psnr = torch.zeros(grid_shape)
    err_ssim = torch.zeros(grid_shape)

    for grid_val, nidx in zip(itertools.product(*grid.values()),
                              np.ndindex(*grid_shape)):
        grid_param_cur = dict(zip(grid.keys(), grid_val))
        print(
            "Current grid parameters (" + str(list(nidx)) + " / " +
            str(grid_shape) + "): " + str(grid_param_cur),
            flush=True,
        )
        x_rec = rec_func(y, **grid_param_cur)
        err[nidx], _ = l2_error(x_rec, x, relative=True, squared=False)
        err_psnr[nidx] = psnr(
            rotate_real(x_rec)[:, 0:1, ...],
            rotate_real(x)[:, 0:1, ...],
            data_range=rotate_real(x)[:, 0:1, ...].max(),
            reduction="mean",
        )
        err_ssim[nidx] = ssim(
            rotate_real(x_rec)[:, 0:1, ...],
            rotate_real(x)[:, 0:1, ...],
            data_range=rotate_real(x)[:, 0:1, ...].max(),
            size_average=True,
        )
        print("Rel. recovery error: {:1.2e}".format(err[nidx]), flush=True)
        print("PSNR: {:.2f}".format(err_psnr[nidx]), flush=True)
        print("SSIM: {:.2f}".format(err_ssim[nidx]), flush=True)
        if err[nidx] < err_min:
            grid_param = grid_param_cur
            err_min = err[nidx]

    return grid_param, err_min, err, err_psnr, err_ssim
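grid_search pairs itertools.product over the grid values with np.ndindex over the grid shape, so every hyper-parameter combination is written to exactly one cell of the error tensors. A minimal, self-contained illustration of that pairing (the grid contents below are made up):

import itertools
import numpy as np
import torch

grid = {"lam": [0.1, 0.01], "iters": [50, 100, 200]}  # hypothetical grid
grid_shape = [len(val) for val in grid.values()]      # [2, 3]
err = torch.zeros(grid_shape)

for grid_val, nidx in zip(itertools.product(*grid.values()),
                          np.ndindex(*grid_shape)):
    grid_param_cur = dict(zip(grid.keys(), grid_val))
    # a real run would call rec_func(y, **grid_param_cur) here
    err[nidx] = sum(grid_param_cur.values())          # placeholder error
print(err)

Both iterators advance the last grid dimension fastest, which is why the parameter tuples and the tensor indices stay in lockstep.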
Example #5
                Y_0_s,
                store_data=True,
                keep_init=keep_init,
                err_measure=err_measure,
            )

            (
                results.loc[idx].X_adv_err[:, s],
                idx_max_adv_err,
            ) = X_adv_err_cur.max(dim=1)
            results.loc[idx].X_ref_err[:, s] = X_ref_err_cur.mean(dim=1)

            for idx_noise in range(len(noise_rel)):
                idx_max = idx_max_adv_err[idx_noise]
                results.loc[idx].X_adv_psnr[idx_noise, s] = psnr(
                    rotate_real(X_adv_cur[idx_noise, ...])[idx_max, 0:1, ...],
                    rotate_real(X_0_s.cpu())[0, 0:1, ...],
                    data_range=4.5,
                    reduction="none",
                )  # normalization as in example-script
                results.loc[idx].X_ref_psnr[idx_noise, s] = psnr(
                    rotate_real(X_ref_cur[idx_noise, ...])[:, 0:1, ...],
                    rotate_real(X_0_s.cpu())[:, 0:1, ...],
                    data_range=4.5,
                    reduction="mean",
                )  # normalization as in example-script
                results.loc[idx].X_adv_ssim[idx_noise, s] = ssim(
                    rotate_real(X_adv_cur[idx_noise, ...])[idx_max, 0:1, ...],
                    rotate_real(X_0_s.cpu())[0, 0:1, ...],
                    data_range=4.5,
                    size_average=False,
Example #6
def loss_func(pred, tar):
    return (mseloss(
        rotate_real(pred)[:, 0:1, ...],
        rotate_real(tar)[:, 0:1, ...],
    ) / pred.shape[0])
Example #7
def _complexloss(reference, prediction):
    loss = mseloss(
        rotate_real(reference)[:, 0:1, ...],
        rotate_real(prediction)[:, 0:1, ...],
    )
    return loss
Example #8
    torchvision.transforms.Compose([
        CropOrPadAndResimulate((320, 320)),
        Flatten(0, -3),
        Normalize(reduction="mean", use_target=True),
    ], ),
}
test_data = AlmostFixedMaskDataset
test_data = test_data("val", **test_data_params)

lo, hi = test_data.get_slices_in_volume(sample_vol)
print("volume slices from {} to {}, selected {}".format(
    lo, hi, lo + sample_sl))
X_VOL = to_complex(
    torch.stack([test_data[sl_idx][2] for sl_idx in range(lo, hi)],
                dim=0)).to(device)
X_MAX = rotate_real(X_VOL)[:, 0:1, ...].max().cpu()
X_0 = to_complex(test_data[lo + sample_sl][2].to(device)).unsqueeze(0)
X_0 = X_0.repeat(it_init, *((X_0.ndim - 1) * (1, )))
Y_0 = cfg_rob.OpA(X_0)

# set range for plotting and similarity indices
v_min = 0.05
v_max = 4.50
print("Pixel values between {} and {}".format(v_min, v_max))

# create result table and load existing results from file
results = pd.DataFrame(columns=[
    "name",
    "X_adv_err",
    "X_ref_err",
    "X_adv_psnr",
Example #9
                "; Noise rel {}/{}".format(idx_noise + 1, len(noise_rel)) +
                " (= {:1.3f})".format(noise_rel[idx_noise].item()),
                flush=True,
            )

            noise_level = noise_rel[idx_noise] * Y_0.norm(
                p=2, dim=(-2, -1), keepdim=True)
            Y = noise_type(Y_0, noise_level)
            X = method.reconstr(Y, noise_rel[idx_noise])

            print(((Y - Y_0).norm(p=2, dim=(-2, -1)) /
                   (Y_0).norm(p=2, dim=(-2, -1))).mean())

            results.loc[idx].X_err[idx_noise, ...] = err_measure(X, X_0)
            results.loc[idx].X_psnr[idx_noise, ...] = psnr(
                torch.clamp(rotate_real(X.cpu())[:, 0:1, ...], v_min, v_max),
                torch.clamp(rotate_real(X_0.cpu())[:, 0:1, ...], v_min, v_max),
                data_range=v_max - v_min,
                reduction="none",
            )
            results.loc[idx].X_ssim[idx_noise, ...] = ssim(
                torch.clamp(rotate_real(X.cpu())[:, 0:1, ...], v_min, v_max),
                torch.clamp(rotate_real(X_0.cpu())[:, 0:1, ...], v_min, v_max),
                data_range=v_max - v_min,
                size_average=False,
            )
            results.loc[idx].X[idx_noise, ...] = X[0:1, ...].cpu()
            results.loc[idx].Y[idx_noise, ...] = Y[0:1, ...].cpu()

# save results
for idx in results.index:
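Example #9 clamps both reconstruction and reference to the plotting range [v_min, v_max] before computing PSNR and SSIM, and passes data_range=v_max - v_min, so the scores refer to the displayed dynamic range rather than to the raw pixel values. The repository's psnr helper is not shown, so the sketch below uses a hand-rolled PSNR purely to illustrate the clamping idiom:

import torch

def psnr_sketch(x, ref, data_range):
    # hypothetical per-sample PSNR for (N, 1, H, W) tensors; this is an
    # assumption, not the repository's psnr() implementation
    mse = ((x - ref) ** 2).flatten(1).mean(dim=1)
    return 10 * torch.log10(data_range ** 2 / mse)

v_min, v_max = 0.05, 4.50                 # plotting range as in Example #8
x = torch.rand(2, 1, 320, 320) * 5.0      # dummy reconstruction
ref = torch.rand(2, 1, 320, 320) * 5.0    # dummy ground truth
score = psnr_sketch(
    torch.clamp(x, v_min, v_max),
    torch.clamp(ref, v_min, v_max),
    data_range=v_max - v_min,
)
print(score)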
Example #10
    torchvision.transforms.Compose([
        CropOrPadAndResimulate((320, 320)),
        Flatten(0, -3),
        Normalize(reduction="mean", use_target=True),
    ], ),
}
test_data = AlmostFixedMaskDataset
test_data = test_data("val", **test_data_params)

lo, hi = test_data.get_slices_in_volume(sample_vol)
print("volume slices from {} to {}, selected {}".format(
    lo, hi, lo + sample_sl))
X_VOL = to_complex(
    torch.stack([test_data[sl_idx][2] for sl_idx in range(lo, hi)],
                dim=0)).to(device)
X_MAX = rotate_real(X_VOL)[:, 0:1, ...].max().cpu()
X_0 = to_complex(test_data[lo + sample_sl][2].to(device)).unsqueeze(0)
print(X_0.min(), X_0.max())
P_0 = _perturbation(X_0.shape[-2:])
X_0 = X_0 + P_0
print(X_0.min(), X_0.max())
Y_0 = cfg_rob.OpA(X_0)

# create result table and load existing results from file
results = pd.DataFrame(columns=["name", "X_err", "X_psnr", "X_ssim", "X", "Y"])
results.name = methods.index
results = results.set_index("name")
# load existing results from file
if os.path.isfile(save_results):
    results_save = pd.read_pickle(save_results)
    for idx in results_save.index:
Example #11
                    " (= {:1.3f})".format(noise_rel[idx_noise].item()),
                    flush=True,
                )

                noise_level = noise_rel[idx_noise] * Y_0_s.norm(
                    p=2, dim=(-2, -1), keepdim=True)
                Y = noise_type(Y_0_s, noise_level)
                X = method.reconstr(Y, noise_rel[idx_noise])

                print(((Y - Y_0_s).norm(p=2, dim=(-2, -1)) /
                       (Y_0_s).norm(p=2, dim=(-2, -1))).mean())

                results.loc[idx].X_err[idx_noise,
                                       s] = err_measure(X, X_0_s).mean()
                results.loc[idx].X_psnr[idx_noise, s] = psnr(
                    rotate_real(X.cpu())[:, 0:1, ...],
                    rotate_real(X_0_s.cpu())[:, 0:1, ...],
                    data_range=4.5,
                    reduction="mean",
                )  # normalization as in ex-script
                results.loc[idx].X_ssim[idx_noise, s] = ssim(
                    rotate_real(X.cpu())[:, 0:1, ...],
                    rotate_real(X_0_s.cpu())[:, 0:1, ...],
                    data_range=4.5,
                    size_average=True,
                )  # normalization as in ex-script

# save results
for idx in results.index:
    results_save.loc[idx] = results.loc[idx]
os.makedirs(save_path, exist_ok=True)
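Examples #10 and #11 cache their result tables as pandas pickles: previously stored rows are read back with pd.read_pickle before new measurements are computed, and the merged table is written out again at the end. A small self-contained sketch of that pattern (paths, method names, and columns below are placeholders, not the repository's configuration):

import os
import pandas as pd

save_path = "./results_demo"
save_results = os.path.join(save_path, "results.pkl")

results = pd.DataFrame(
    index=pd.Index(["method_a", "method_b"], name="name"),
    columns=["X_err", "X_psnr", "X_ssim"],
)

# load existing results from file, if any
if os.path.isfile(save_results):
    results_save = pd.read_pickle(save_results)
    for idx in results_save.index:
        if idx in results.index:
            results.loc[idx] = results_save.loc[idx]

# ... fill results.loc[...] with new measurements here ...

# save results
os.makedirs(save_path, exist_ok=True)
results.to_pickle(save_results)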