Exemplo n.º 1
0
 def _process_data(self, kspace_data, gt_data):
     """Apply a sampling mask to ``kspace_data`` and prepare the target.

     The mask is obtained by calling ``self.mask_func`` with the shape of
     ``kspace_data`` and broadcasting the result to it with ``expand_as``.
     Both the k-space tensor and the mask get a new leading dimension
     (``unsqueeze(0)``) and are passed through ``unprep_fft_channel``; the
     mask is then applied by elementwise multiplication. ``gt_data`` is
     converted with ``to_complex``.

     Returns:
         Tuple ``(kspace_masked, mask, gt_data)``. If
         ``self.keep_mask_as_func`` is truthy, the second element is the
         callable ``self.mask_func`` itself rather than the mask tensor.
     """
     sampling = self.mask_func(kspace_data.shape).expand_as(kspace_data)
     kspace_batched = unprep_fft_channel(kspace_data.unsqueeze(0))
     sampling = unprep_fft_channel(sampling.unsqueeze(0))
     undersampled = kspace_batched * sampling
     target = to_complex(gt_data)
     # With ``keep_mask_as_func`` set, hand back the mask function itself.
     mask_out = self.mask_func if self.keep_mask_as_func else sampling
     return undersampled, mask_out, target
Exemplo n.º 2
0
def text_perturbation(shape):
    """Render an "IEEE" text perturbation as a complex image tensor.

    An "I" followed by three progressively smaller "E" glyphs is drawn with
    value 255 onto an empty palette ("P" mode) image of size ``shape``. The
    font is the first installed TrueType font whose path contains
    ``DejaVuSansMono.ttf``. Glyph positions are fractions of ``shape``.
    The pixel values are divided by 255 and returned as
    ``to_complex(t_img.unsqueeze(0)).unsqueeze(0)`` moved to ``device``.

    Args:
        shape: Image size, passed to ``PIL.Image.new`` (which interprets it
            as ``(width, height)``) and used to reshape the pixel data.

    Returns:
        The perturbation tensor described above.

    Raises:
        FileNotFoundError: If no ``DejaVuSansMono.ttf`` font is installed.
    """
    # (text, font size, relative x, relative y) of every glyph, drawn in order.
    glyphs = (
        ("I", 16, 0.27, 0.350),
        ("E", 13, 0.30, 0.360),
        ("E", 10, 0.33, 0.3725),
        ("E", 7, 0.355, 0.385),
    )

    # Adapt the font choice to the fonts available on your system.
    system_fonts = font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
    font_path = next(
        (font for font in system_fonts if "DejaVuSansMono.ttf" in font), None
    )
    if font_path is None:
        # Fail with an explicit message instead of an opaque IndexError.
        raise FileNotFoundError(
            "text_perturbation needs the font 'DejaVuSansMono.ttf', which was "
            "not found among the installed system fonts"
        )

    # Create an empty image and draw the glyphs onto it.
    img = Image.new("P", shape)
    draw = ImageDraw.Draw(img)
    for text, size, rel_x, rel_y in glyphs:
        font = ImageFont.truetype(font_path, size=size)
        draw.text(
            (rel_x * shape[0], rel_y * shape[1]),
            text,
            255,
            font=font,
            stroke_width=0,
        )

    # NOTE(review): PIL stores pixels row by row (height x width), so reshaping
    # to ``shape`` == (width, height) only keeps the picture intact for square
    # shapes — confirm non-square shapes are never requested.
    pixels = np.asarray(img, dtype=np.float32).reshape(shape)
    t_img = torch.tensor(pixels) / 255
    return to_complex(t_img.unsqueeze(0)).unsqueeze(0).to(device)
        Flatten(0, -3),
        Normalize(reduction="mean", use_target=True),
    ], ),
}
# Validation split of the almost-fixed-mask dataset.
test_data = AlmostFixedMaskDataset("val", **test_data_params)

# Select the middle slice, (lo + hi) // 2, of each of the first 30 volumes.
vols = range(30)
slices_in_vols = list(map(test_data.get_slices_in_volume, vols))
slices_selected = [
    range(mid, mid + 1)
    for mid in ((first + last) // 2 for first, last in slices_in_vols)
]
samples = np.concatenate(slices_selected)

# Element 2 of each selected item -> X_0, element 0 -> Y_0 (presumably the
# image and its measurement — TODO confirm the dataset item layout).
X_0 = to_complex(torch.stack([test_data[s][2] for s in samples]).to(device))

Y_0 = torch.stack([test_data[s][0] for s in samples]).to(device)
# Keep only the entries where ``mask[0, 0]`` is set.
Y_0 = im2vec(to_complex(Y_0))[..., im2vec(mask[0, 0, :, :].bool())]

# ----- noise setup --------
# noise_steps relative noise levels, log-spaced from noise_min to noise_max.
noise_min, noise_max, noise_steps = 1e-3, 0.03, 50
noise_rel = torch.tensor(
    np.logspace(np.log10(noise_min), np.log10(noise_max), num=noise_steps)
).float()
noise_rel = (torch.cat([
    torch.zeros(1).float(),
    noise_rel,
    0.060 * torch.ones(1).float(),
        Flatten(0, -3),
        Normalize(reduction="mean", use_target=True),
    ], ),
}
test_data = AlmostFixedMaskDataset("val", **test_data_params)

# Keep only the middle slice, (lo + hi) // 2, of each of the first 30 volumes.
vols = range(30)
slices_in_vols = list(map(test_data.get_slices_in_volume, vols))
slices_selected = [
    range(mid, mid + 1)
    for mid in ((first + last) // 2 for first, last in slices_in_vols)
]
samples = np.concatenate(slices_selected)

# Stack element 2 (-> X_0) and element 0 (-> Y_0) of every selected item.
X_0 = to_complex(torch.stack([test_data[s][2] for s in samples]).to(device))

Y_0 = torch.stack([test_data[s][0] for s in samples]).to(device)
# Keep only the entries where ``cfg_rob.mask[0, 0]`` is set.
Y_0 = im2vec(to_complex(Y_0))[..., im2vec(cfg_rob.mask[0, 0, :, :].bool())]

# number of measurement samples (it_init) and how many to keep (keep_init);
# meaning inferred from the original "no meas samples" note — confirm.
it_init, keep_init = 6, 3

noise_type = noise_gaussian

# selected relative noise levels
noise_rel = torch.tensor([0.00, 0.0025, 0.005, 0.01, 0.015, 0.02, 0.025])
    True,
    "transform":
    torchvision.transforms.Compose([
        CropOrPadAndResimulate((320, 320)),
        Flatten(0, -3),
        Normalize(reduction="mean", use_target=True),
    ], ),
}
test_data = AlmostFixedMaskDataset("val", **test_data_params)

# Slice index range [lo, hi) of the chosen volume, and the slice picked in it.
lo, hi = test_data.get_slices_in_volume(sample_vol)
print("volume slices from {} to {}, selected {}".format(lo, hi, lo + sample_sl))

# Element 2 of every slice in the volume, stacked along a new first axis.
X_VOL = to_complex(
    torch.stack([test_data[idx][2] for idx in range(lo, hi)], dim=0)
).to(device)
# Maximum over channel 0 of ``rotate_real(X_VOL)``, kept on the CPU.
X_MAX = rotate_real(X_VOL)[:, 0:1, ...].max().cpu()

# The selected slice, replicated it_init times along the leading dimension.
X_0 = to_complex(test_data[lo + sample_sl][2].to(device)).unsqueeze(0)
X_0 = X_0.repeat(it_init, *([1] * (X_0.ndim - 1)))
Y_0 = cfg_rob.OpA(X_0)

# set range for plotting and similarity indices
v_min, v_max = 0.05, 4.50
print("Pixel values between {} and {}".format(v_min, v_max))

# create result table and load existing results from file
results = pd.DataFrame(columns=[
    "name",
    "X_adv_err",
Exemplo n.º 6
0
# reconstruction error measure
err_measure = err_measure_l2

# reconstruction methods to evaluate (rows of ``methods``, in this order)
methods_include = ["L1", "UNet jit", "Tiramisu EE jit", "UNet It jit"]
methods = methods.loc[methods_include]

# methods excluded from (re-)performing attacks
methods_no_calc = ["L1", "UNet jit", "Tiramisu EE jit", "UNet It jit"]

# ----- perform attack -----

# Load the selected test sample and replicate it ``it`` times along a new
# leading dimension.
test_data = IPDataset("test", config.DATA_PATH)
X_0 = to_complex(test_data[sample][0].to(device)).unsqueeze(0)
X_0 = X_0.repeat(it, *([1] * (X_0.ndim - 1)))
Y_0 = cfg_rob.OpA(X_0)

# Result table with one row per selected method, indexed by "name".
results = pd.DataFrame(columns=["name", "X_err", "X_psnr", "X_ssim", "X", "Y"])
results.name = methods.index
results = results.set_index("name")

# Reuse stored rows for methods that are still selected.
# NOTE: ``pd.read_pickle`` can execute arbitrary code — only point
# ``save_results`` at trusted files.
if not os.path.isfile(save_results):
    results_save = results
else:
    results_save = pd.read_pickle(save_results)
    for idx in results_save.index:
        if idx not in results.index:
            continue
        results.loc[idx] = results_save.loc[idx]
Exemplo n.º 7
0
# methods whose attacks are not (re-)computed
methods_no_calc = [
    "L1",
    "UNet jit",
    "UNet EE jit",
    "Tiramisu jit",
    "Tiramisu EE jit",
    "UNet It jit",
]

# ----- perform attack -----

# Stack element 0 of each selected test item into one batch on ``device``
# and apply ``cfg_rob.OpA`` to it.
test_data = IPDataset("test", config.DATA_PATH)
X_0 = to_complex(torch.stack([test_data[s][0] for s in samples]).to(device))
Y_0 = cfg_rob.OpA(X_0)

# Empty result table: error / PSNR / SSIM columns for the "adv" and "ref"
# reconstructions (presumably adversarial vs. reference — confirm).
results = pd.DataFrame(
    columns=[
        "name",
        "X_adv_err", "X_ref_err",
        "X_adv_psnr", "X_ref_psnr",
        "X_adv_ssim", "X_ref_ssim",
    ]
)
results.name = methods.index
Exemplo n.º 8
0
 def __call__(self, imgs):
     """Convert every image in ``imgs`` with ``to_complex``; return a tuple."""
     return tuple(map(to_complex, imgs))
Exemplo n.º 9
0
def square_perturbation(shape):
    """Return a complex image tensor containing a small filled square.

    A ``torch.zeros(shape)`` image is set to 1 on the 3x3 neighbourhood
    centred at ``(int(0.4675 * shape[0]), int(0.39 * shape[1]))``, clipped
    to the image bounds. The result gets a leading dimension
    (``unsqueeze(0)``), is passed through ``to_complex``, gets another
    leading dimension and is moved to ``device``.

    Args:
        shape: Shape of the image tensor; the square is positioned relative
            to its first two entries.
    """
    t_img = torch.zeros(shape)
    x, y = int(0.4675 * shape[0]), int(0.39 * shape[1])
    # Clamp the lower bounds at 0: a negative slice start (x or y == 0 for
    # tiny shapes) would count from the end of the axis and drop rows/columns
    # that belong to the square.
    t_img[max(x - 1, 0):x + 2, max(y - 1, 0):y + 2] = 1

    return to_complex(t_img.unsqueeze(0)).unsqueeze(0).to(device)