Exemplo n.º 1
0
    def __call__(self, inputs):
        """Fit the target to ``self.shape`` and re-simulate k-space from it.

        The target is zero-padded along its last two axes (split evenly, with
        the extra pixel of an odd amount going to the right/bottom side) until
        it is at least ``self.shape``, then center-cropped to ``self.shape``.
        New k-space is ``transforms.fft2`` of the cropped target multiplied by
        a mask drawn from the mask callable for that k-space shape.

        If axis -3 of the cropped target does not have size 2, only the two
        entries starting at index ``2 * (size // 4)`` are kept (presumably the
        real/imaginary pair of the central slice -- TODO confirm).

        Args:
            inputs: 3-tuple ``(kspace, mask, target)``. The incoming k-space is
                ignored; ``mask`` must be callable with a shape.

        Returns:
            ``(new_kspace, new_mask, target_cropped)``.
        """
        _, mask_func, target = inputs

        # Zero padding still needed on each of the last two axes (never < 0).
        pad_rows = max(0, self.shape[0] - target.shape[-2])
        pad_cols = max(0, self.shape[1] - target.shape[-1])
        # F.pad ordering: (left, right) for the last axis, then (top, bottom)
        # for the second-to-last; (k + 1) // 2 is ceil(k / 2) for k >= 0.
        padding = (
            pad_cols // 2,
            (pad_cols + 1) // 2,
            pad_rows // 2,
            (pad_rows + 1) // 2,
        )
        padded = torch.nn.functional.pad(target, padding)
        cropped = transforms.center_crop(padded, self.shape)

        # Forward-simulate undersampled k-space from the fitted target.
        spectrum = transforms.fft2(prep_fft_channel(cropped))
        sampling = mask_func(spectrum.shape).expand_as(spectrum)
        measured = unprep_fft_channel(spectrum * sampling)
        sampling = unprep_fft_channel(sampling)

        num_channels = cropped.shape[-3]
        if num_channels != 2:
            start = 2 * (num_channels // 4)
            cropped = cropped[..., start : start + 2, :, :]

        return measured, sampling, cropped
Exemplo n.º 2
0
 def _process_data(self, kspace_data, gt_data):
     """Undersample ``kspace_data`` with a mask from ``self.mask_func``.

     The mask is generated for ``kspace_data.shape`` and expanded to it. Both
     k-space and mask get a new leading axis (``unsqueeze(0)``) and go through
     ``unprep_fft_channel`` before being multiplied. The ground truth is
     converted with ``to_complex``.

     Returns:
         ``(kspace_masked, mask, gt_data)`` where ``mask`` is the converted
         mask tensor, or ``self.mask_func`` itself when
         ``self.keep_mask_as_func`` is set.
     """
     full_mask = self.mask_func(kspace_data.shape).expand_as(kspace_data)
     kspace_out = unprep_fft_channel(kspace_data.unsqueeze(0))
     mask_out = unprep_fft_channel(full_mask.unsqueeze(0))
     undersampled = kspace_out * mask_out
     target = to_complex(gt_data)
     returned_mask = self.mask_func if self.keep_mask_as_func else mask_out
     return undersampled, returned_mask, target
Exemplo n.º 3
0
    def _process_data(self, kspace_data, gt_data):
        """Undersample ``kspace_data`` with a mask seeded by ``self.seed``.

        The mask from ``self.mask_func(shape, seed=self.seed)`` is expanded to
        the k-space shape. Both k-space and mask get a new leading axis
        (``unsqueeze(0)``) and go through ``unprep_fft_channel`` before being
        multiplied. When ``self.simulate_gt`` is set, the ground truth receives
        the same ``unsqueeze(0)`` + ``unprep_fft_channel`` treatment; otherwise
        it is returned untouched.

        Returns:
            ``(kspace_masked, mask, gt_data)``. With ``self.keep_mask_as_func``
            the mask entry is a callable ``shape -> self.mask_func(shape,
            seed=self.seed)`` instead of the tensor.
        """
        raw_mask = self.mask_func(kspace_data.shape, seed=self.seed)
        full_mask = raw_mask.expand_as(kspace_data)

        kspace_out = unprep_fft_channel(kspace_data.unsqueeze(0))
        mask_out = unprep_fft_channel(full_mask.unsqueeze(0))
        undersampled = kspace_out * mask_out

        target = (
            unprep_fft_channel(gt_data.unsqueeze(0))
            if self.simulate_gt
            else gt_data
        )

        if not self.keep_mask_as_func:
            return undersampled, mask_out, target

        def seeded_mask(shape):
            # self.seed is read at call time, not captured here.
            return self.mask_func(shape, seed=self.seed)

        return undersampled, seeded_mask, target
Exemplo n.º 4
0
 def __call__(self, inputs):
     """Scale the inputs by a norm of a reference image.

     The reference is ``inputs[-1]`` when ``self.use_target`` is set.
     Otherwise it is computed from ``inputs[0]`` with an inverse 2D FFT
     (wrapped by ``prep_fft_channel`` / ``unprep_fft_channel``).

     ``self.p`` is forwarded to ``torch.norm``. The string ``"inf"`` is
     accepted as an alias for ``float("inf")``, because ``torch.norm`` itself
     rejects the string. When ``self.reduction == "mean"`` and ``p`` is not
     +inf, the norm is divided by ``numel ** (1 / p)``, turning it into a
     per-element power mean. For the infinity norm no rescaling is applied.

     Args:
         inputs: 2-tuple ``(a, target)`` or 3-tuple ``(a, b, target)``.

     Returns:
         ``(a / norm, target / norm)`` for two inputs, otherwise
         ``(a / norm, b, target / norm)``; the middle entry is passed through.
     """
     # torch.norm only understands float("inf"); map the string alias to it.
     p = float("inf") if self.p == "inf" else self.p
     if self.use_target:
         tar = inputs[-1]
     else:
         tar = unprep_fft_channel(
             transforms.ifft2(prep_fft_channel(inputs[0])))
     norm = torch.norm(tar, p=p)
     if self.reduction == "mean" and p != float("inf"):
         # Out of place on purpose: torch.norm's backward reuses its output,
         # so an in-place update can trip autograd's in-place version check.
         norm = norm / tar.numel() ** (1 / p)
     if len(inputs) == 2:
         return inputs[0] / norm, inputs[1] / norm
     return inputs[0] / norm, inputs[1], inputs[2] / norm
)
from reconstruction_methods import admm_l1_rec_diag, grid_search

# ----- load configuration -----
import config  # isort:skip

# ------ setup ----------
# Setup for an L1 grid search with a Fourier operator (cf. file_name and
# save_path below). Hard-wired to CUDA device 0, so a GPU is required.
device = torch.device("cuda:0")

# Result location and filename prefix. How they are used lies outside this
# excerpt (presumably results are written under save_path with this prefix).
file_name = "grid_search_l1_fourier_"
save_path = os.path.join(config.RESULTS_PATH, "grid_search_l1")

# ----- operators --------
# Image grid size (presumably height, width in pixels -- TODO confirm).
n = (320, 320)
# Radial undersampling mask generator. The meaning of 50 is defined by
# RadialMaskFunc (presumably the number of radial spokes -- TODO confirm).
mask_func = RadialMaskFunc(n, 50)
# Evaluate the mask once for shape (1, 1, *n, 1) and convert it with
# unprep_fft_channel before building the operator from it.
mask = unprep_fft_channel(mask_func((1, 1) + n + (1, )))
# Forward operator built from the fixed mask (presumably a masked 2D FFT).
OpA = Fourier(mask)
# Periodic analysis operator on the n grid (presumably total variation),
# allocated on `device`.
OpTV = TVAnalysisPeriodic(n, device=device)

# ----- load test data --------
test_data_params = {
    "mask_func":
    mask_func,
    "seed":
    1,
    "filter": [filter_acquisition_no_fs],
    "num_sym_slices":
    0,
    "multi_slice_gt":
    False,
    "keep_mask_as_func":
Exemplo n.º 6
0
    RadialMaskFunc,
    rotate_real,
    unprep_fft_channel,
)

# ----- load configuration -----
import config  # isort:skip

# ----- global configuration -----
# Select the non-interactive "agg" backend (assuming mpl is matplotlib), so
# figures can be rendered without a display.
mpl.use("agg")
# Hard-wired to CUDA device 0; also make it the current CUDA device.
device = torch.device("cuda:0")
torch.cuda.set_device(0)

# ----- measurement configuration -----
# Radial mask on the config.n grid. The meaning of 40 is defined by
# RadialMaskFunc (presumably the number of radial spokes -- TODO confirm).
mask_func = RadialMaskFunc(config.n, 40)
# Evaluate the mask once for shape (1, 1, *config.n, 1) and convert it with
# unprep_fft_channel.
mask = unprep_fft_channel(mask_func((1, 1) + config.n + (1, )))
# Forward operator built from the fixed mask (presumably a masked 2D FFT).
OpA = Fourier(mask)
# Inversion layer built from the same mask; learnable=False presumably keeps
# its parameters fixed during training -- TODO confirm in LearnableInverter.
inverter = LearnableInverter(config.n, mask, learnable=False)

# ----- network configuration -----
# Keyword arguments for the subnetwork; consumer not shown in this excerpt.
# The key names suggest an FC-DenseNet/Tiramisu-style architecture and two
# channels presumably mean real/imaginary parts -- TODO confirm.
subnet_params = {
    "in_channels": 2,
    "out_channels": 2,
    "drop_factor": 0.0,
    "down_blocks": (5, 7, 9, 12, 15),
    "up_blocks": (15, 12, 9, 7, 5),
    "pool_factors": (2, 2, 2, 2, 2),
    "bottleneck_layers": 20,
    "growth_rate": 16,
    "out_chans_first_conv": 16,
}
Exemplo n.º 7
0
 def __call__(self, inputs):
     """Return the inverse-FFT image of the k-space input alongside the target.

     Args:
         inputs: 3-tuple ``(kspace, mask, target)``; the mask is ignored.

     Returns:
         ``(image, target)``, where ``image`` is ``kspace`` passed through
         ``prep_fft_channel``, ``transforms.ifft2`` and ``unprep_fft_channel``.
     """
     kspace, _unused_mask, target = inputs
     image = transforms.ifft2(prep_fft_channel(kspace))
     image = unprep_fft_channel(image)
     return image, target
Exemplo n.º 8
0
from reconstruction_methods import admm_l1_rec_diag


# ------ setup ----------
# Hard-wired to CUDA device 0; also make it the current CUDA device.
device = torch.device("cuda:0")
torch.cuda.set_device(0)
# Fixed seed passed to mask_func so that the random mask is the same on
# every run.
mask_seed = 123

# ----- operators -----
# Image grid size (presumably height, width in pixels -- TODO confirm).
n = (368, 368)
# Random undersampling mask (presumably fastMRI's subsample module) keeping a
# 0.08 center fraction at 4x acceleration.
mask_func = subsample.RandomMaskFunc(
    center_fractions=[0.08], accelerations=[4]
)
# Draw the seeded mask for shape (1, 1, *n, 1), broadcast it to that full
# shape (the drawn mask presumably varies along fewer axes -- TODO confirm),
# then convert it with unprep_fft_channel.
mask = unprep_fft_channel(
    mask_func((1, 1) + n + (1,), seed=mask_seed).expand_as(
        torch.ones((1, 1) + n + (1,))
    )
)
# Forward operator built from the fixed mask (presumably a masked 2D FFT).
OpA = Fourier(mask)
# Periodic analysis operator on the n grid (presumably total variation),
# allocated on `device`.
OpTV = TVAnalysisPeriodic(n, device=device)

# ----- methods --------
# Registry of methods to evaluate, one row per method, indexed by "name".
methods = pd.DataFrame(columns=["name", "info", "reconstr", "attacker", "net"])
methods = methods.set_index("name")

# Reference noise model (noise_gaussian is defined outside this excerpt).
noise_ref = noise_gaussian

# ----- set up L1 --------
# grid search parameters for L1 via admm
grid_search_file = os.path.join(
    config.RESULTS_PATH, "grid_search_l1", "grid_search_l1_fourier_all.pkl"