Example #1
0
 def soft_dc(self, x, input):
     """Return the ``self.lambda_``-scaled masked residual against ``input['kspace']``.

     When ``self.space == 'img-space'`` the estimate ``x`` is first taken to
     k-space with ``T.fft2(sens_expand(x, input['sens_maps']))``. The masked
     residual is then taken back with ``sens_reduce(T.ifft2(...), ...)``, so the
     result lives in the same space as ``x``. In any other space ``x`` is
     compared with ``input['kspace']`` directly.

     Positions where ``input['mask']`` is false are replaced by ``self.zero``
     (``input['mask']`` is used as the ``torch.where`` condition).
     NOTE(review): presumably this is the gradient of a soft data-consistency
     penalty; confirm against the surrounding model.

     Args:
         x: current estimate, in image or k-space depending on ``self.space``.
         input: mapping providing ``'mask'`` and ``'kspace'``, plus
             ``'sens_maps'`` when ``self.space == 'img-space'``.

     Returns:
         ``self.lambda_`` times the (possibly image-space) masked residual.
     """
     in_image_space = self.space == 'img-space'
     if in_image_space:
         predicted = T.fft2(sens_expand(x, input['sens_maps']))
     else:
         predicted = x
     residual = torch.where(input['mask'], predicted - input['kspace'], self.zero)
     if in_image_space:
         # Map the k-space residual back so it matches the space of ``x``.
         residual = sens_reduce(T.ifft2(residual), input['sens_maps'])
     return self.lambda_ * residual
Example #2
0
 def forward(self, x):
     """Apply ``T.fft2`` to ``x`` and return the transformed tensor.

     The commented-out ``pdb`` NaN checks that used to surround this call were
     dead debugging code and have been removed. Use a debugger or
     ``torch.autograd.detect_anomaly`` when NaNs need to be tracked down.
     """
     return T.fft2(x)
Example #3
0
    def net_forward(self, x, input):
        """Run ``self.net`` on ``x``, converting in and out of image space if needed.

        When ``self.space == 'k-space'`` the input is first mapped with
        ``sens_reduce(T.ifft2(x), input['sens_maps'])``. The network output is
        afterwards mapped back with ``T.fft2(sens_expand(..., input['sens_maps']))``.

        Before the network call the slices are merged with
        ``merge_multi_slice(..., cat_dim=-2)`` and a singleton axis is inserted at
        dim 1. Afterwards they are split again with ``unmerge_multi_slice(..., 2)``.
        NOTE(review): the exact layout produced by these helpers is defined
        elsewhere; confirm the shape contract with their implementations.

        Args:
            x: input tensor, in k-space or image space depending on ``self.space``.
            input: mapping providing ``'sens_maps'`` when ``self.space == 'k-space'``.

        Returns:
            The network output, in the same space as ``x``.
        """
        image = x
        if self.space == 'k-space':
            image = sens_reduce(T.ifft2(x), input['sens_maps'])

        batched = merge_multi_slice(image, cat_dim=-2).unsqueeze(1).contiguous()
        refined = unmerge_multi_slice(self.net(batched), 2).contiguous()

        # Re-read self.space after the network call, as the original did.
        if self.space != 'k-space':
            return refined
        return T.fft2(sens_expand(refined, input['sens_maps']))
Example #4
0
def test_fft2(shape):
    """Check ``transforms.fft2`` against a NumPy centered, orthonormal 2-D FFT.

    The reference result is computed as ``fftshift(fft2(ifftshift(z), norm='ortho'))``
    over the last two axes of the complex input ``z``.

    Args:
        shape: shape of the complex input, excluding the trailing real/imag axis.
            Any sequence of ints is accepted (list or tuple); a trailing axis of
            size 2 is appended to hold the real and imaginary parts.
    """
    # ``list(...)`` lets tuple-valued parametrizations work too; ``shape + [2]``
    # raised TypeError for tuples.
    data = create_input(list(shape) + [2])

    out_torch = transforms.fft2(data).numpy()
    # Recombine the trailing (real, imag) axis into a complex array.
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    expected = utils.tensor_to_complex_np(data)
    expected = np.fft.ifftshift(expected, (-2, -1))
    expected = np.fft.fft2(expected, norm='ortho')
    expected = np.fft.fftshift(expected, (-2, -1))
    assert np.allclose(out_torch, expected)
Example #5
0
 def soft_dc(self, x, input):
     """Return the ``self.lambda_``-scaled masked residual against ``input['kspace']``.

     When ``self.space == 'img-space'`` the estimate ``x`` is first taken to
     k-space with ``T.fft2(sens_expand(x, input['sens_maps']))``. The masked
     residual is then taken back with ``sens_reduce(T.ifft2(...), ...)``, so the
     result lives in the same space as ``x``. Positions where ``input['mask']``
     is false are replaced by ``self.zero``.

     NOTE(review): presumably this is the gradient of a soft data-consistency
     penalty; confirm against the surrounding model.

     Args:
         x: current estimate, in image or k-space depending on ``self.space``.
         input: mapping providing ``'mask'`` and ``'kspace'``, plus
             ``'sens_maps'`` when ``self.space == 'img-space'``.

     Returns:
         ``self.lambda_`` times the (possibly image-space) masked residual.
     """
     # Dead commented-out ``pdb`` NaN checks were removed; behaviour is unchanged.
     if self.space == 'img-space':
         x = T.fft2(sens_expand(x, input['sens_maps']))
     x = torch.where(input['mask'], x - input['kspace'], self.zero)
     if self.space == 'img-space':
         x = sens_reduce(T.ifft2(x), input['sens_maps'])
     return self.lambda_ * x
Example #6
0
    def net_forward(self, x, input):
        """Run ``self.net`` on ``x``, converting in and out of image space if needed.

        When ``self.space == 'k-space'`` the input is first mapped with
        ``sens_reduce(T.ifft2(x), input['sens_maps'])``. The network output is
        afterwards mapped back with ``T.fft2(sens_expand(..., input['sens_maps']))``.

        Slices are merged with ``merge_multi_slice(..., cat_dim=-2)`` and given a
        singleton axis at dim 1 before ``self.net``, then split again with
        ``unmerge_multi_slice(..., 2)``.
        NOTE(review): confirm the shape contract with those helpers' definitions.

        Args:
            x: input tensor, in k-space or image space depending on ``self.space``.
            input: mapping providing ``'sens_maps'`` when ``self.space == 'k-space'``.

        Returns:
            The network output, in the same space as ``x``.
        """
        # Removed: unused ``xinitial`` local and dead commented-out pdb NaN checks.
        if self.space == 'k-space':
            x = sens_reduce(T.ifft2(x), input['sens_maps'])

        x = merge_multi_slice(x, cat_dim=-2).unsqueeze(1).contiguous()
        x = self.net(x)
        x = unmerge_multi_slice(x, 2).contiguous()

        if self.space == 'k-space':
            x = T.fft2(sens_expand(x, input['sens_maps']))
        return x
Example #7
0
 def forward(self, x):
     """Apply ``T.fft2`` to ``x`` and return the transformed tensor."""
     transformed = T.fft2(x)
     return transformed