Exemplo n.º 1
0
    def __call__(self, inputs):
        """Resize the target to ``self.shape`` and re-simulate masked k-space.

        Steps:
          1. Zero-pad the last two axes of ``target`` up to ``self.shape``.
             An odd padding amount puts the extra element on the
             right/bottom side.
          2. Center-crop the padded target to ``self.shape``.
          3. FFT the cropped target, call ``mask`` with the resulting k-space
             shape, broadcast that mask to the k-space, and multiply.
          4. If the channel axis (-3) of the cropped target is not of size
             2, keep only two channels, starting at index ``2 * (C // 4)``.

        Args:
            inputs: a 3-tuple ``(kspace, mask, target)``. The incoming
                ``kspace`` is ignored. ``mask`` must be a callable that takes
                a shape.

        Returns:
            ``(new_kspace, new_mask, target_cropped)``.
        """
        _kspace, mask, target = inputs
        rows, cols = self.shape[0], self.shape[1]

        # How much the last two axes fall short of the requested shape.
        short_rows = max(0, rows - target.shape[-2])
        short_cols = max(0, cols - target.shape[-1])
        # F.pad order for the last two dims is (left, right, top, bottom).
        # ``n - n // 2`` is ceil(n / 2), so an odd remainder goes right/bottom.
        pad_spec = (
            short_cols // 2,
            short_cols - short_cols // 2,
            short_rows // 2,
            short_rows - short_rows // 2,
        )
        padded = torch.nn.functional.pad(target, pad_spec)

        # Trims any axis that is still larger than the requested shape.
        cropped = transforms.center_crop(padded, self.shape)

        # Re-simulate undersampling from the resized image.
        # NOTE(review): prep/unprep_fft_channel presumably reorder the
        # real/imag channel layout around the FFT — confirm.
        full_kspace = transforms.fft2(prep_fft_channel(cropped))
        sampling = mask(full_kspace.shape).expand_as(full_kspace)
        resampled_kspace = unprep_fft_channel(full_kspace * sampling)
        sampling_out = unprep_fft_channel(sampling)

        channels = cropped.shape[-3]
        if channels != 2:
            # Keep a single pair of channels. The index ``2 * (C // 4)`` is
            # presumably the middle real/imag pair — confirm with the caller.
            first = 2 * (channels // 4)
            cropped = cropped[..., first : first + 2, :, :]

        return resampled_kspace, sampling_out, cropped
Exemplo n.º 2
0
 def __call__(self, inputs):
     """Divide the data tensors in ``inputs`` by a reference norm.

     The reference tensor depends on ``self.use_target``:
       * if set, it is ``inputs[-1]``;
       * otherwise, it is the inverse FFT of ``inputs[0]``.

     Its ``self.p``-norm is taken over all entries. ``self.p`` may be a
     number, ``"fro"``, or ``"inf"``.

     When ``self.reduction == "mean"`` and the order is finite, the norm is
     further divided by ``numel ** (1 / p)``. For ``"fro"`` this uses p = 2.

     Args:
         inputs: a 2-tuple ``(a, b)``, in which case both elements are
             divided by the norm. Any other length is treated as
             ``(a, mask, c)``: ``a`` and ``c`` are divided and ``mask`` is
             passed through unchanged.

     Returns:
         A 2-tuple or 3-tuple, following the cases above.
     """
     if self.use_target:
         tar = inputs[-1]
     else:
         tar = unprep_fft_channel(
             transforms.ifft2(prep_fft_channel(inputs[0])))
     # torch.norm only accepts 'fro' / 'nuc' as string orders, so the
     # string "inf" must be converted to a float before calling it.
     order = float("inf") if self.p == "inf" else self.p
     norm = torch.norm(tar, p=order)
     if self.reduction == "mean":
         # The Frobenius norm is the 2-norm over all entries. Infinite orders
         # need no size correction (numel ** (1 / inf) == 1).
         exponent = 2 if order == "fro" else order
         if exponent not in (float("inf"), float("-inf")):
             norm /= tar.numel() ** (1 / exponent)
     # NOTE(review): an all-zero reference gives norm == 0, so the divisions
     # below produce inf/nan. Confirm whether callers can hit this.
     if len(inputs) == 2:
         return inputs[0] / norm, inputs[1] / norm
     return inputs[0] / norm, inputs[1], inputs[2] / norm
Exemplo n.º 3
0
 def __call__(self, inputs):
     """Map k-space to image space and discard the mask.

     Args:
         inputs: a 3-tuple ``(kspace, mask, target)``.

     Returns:
         ``(image, target)``, where ``image`` is the inverse FFT of
         ``kspace`` and ``target`` is returned unchanged.
     """
     kspace, _unused_mask, target = inputs
     # Wrap the inverse FFT in the same prep/unprep channel conversion
     # used around the forward FFT.
     prepared = prep_fft_channel(kspace)
     image = unprep_fft_channel(transforms.ifft2(prepared))
     return image, target