def soft_dc(self, x, input):
    """Compute the lambda-weighted data-consistency residual.

    The residual ``x - input['kspace']`` is kept where ``input['mask']`` is
    true and replaced by ``self.zero`` elsewhere, then scaled by
    ``self.lambda_``.

    When ``self.space == 'img-space'``, ``x`` is first mapped forward with
    ``sens_expand`` followed by ``T.fft2``. The masked residual is then mapped
    back with ``T.ifft2`` followed by ``sens_reduce``. In any other mode ``x``
    is compared to ``input['kspace']`` directly (presumably it is already
    k-space; confirm with callers).

    Args:
        x: Current estimate, in the domain selected by ``self.space``.
        input: Mapping providing ``'mask'`` and ``'kspace'``, plus
            ``'sens_maps'`` in image-space mode. ``'mask'`` is used as the
            ``torch.where`` condition, so it is assumed boolean and
            broadcastable against ``x``. TODO: confirm.

    Returns:
        ``self.lambda_`` times the masked residual.
    """
    image_domain = self.space == 'img-space'
    if image_domain:
        x = T.fft2(sens_expand(x, input['sens_maps']))
    # Keep the mismatch only where the mask is set; everything else gets self.zero.
    residual = torch.where(input['mask'], x - input['kspace'], self.zero)
    if image_domain:
        residual = sens_reduce(T.ifft2(residual), input['sens_maps'])
    return self.lambda_ * residual
def forward(self, x):
    """Return ``T.fft2(x)``, the 2-D FFT of ``x`` as provided by the ``T`` module.

    Args:
        x: Input tensor, in whatever layout ``T.fft2`` expects
            (presumably a trailing real/imag axis; confirm against ``T``).

    Returns:
        The result of ``T.fft2(x)``.
    """
    return T.fft2(x)
def net_forward(self, x, input):
    """Run the learned network ``self.net`` on ``x``.

    In ``'k-space'`` mode the input is first mapped with ``T.ifft2`` followed
    by ``sens_reduce``. The output is mapped back with ``sens_expand``
    followed by ``T.fft2``, so the result is in the same domain as the input.
    In any other mode ``x`` goes to the network unchanged.

    Slices are merged with ``merge_multi_slice(..., cat_dim=-2)`` and a
    singleton dimension is inserted at position 1 before ``self.net``.
    Afterwards the result is split back with ``unmerge_multi_slice(..., 2)``.
    The meaning of that ``2`` is defined by the helper.
    TODO: confirm it matches the merge above.

    Args:
        x: Current estimate, in the domain selected by ``self.space``.
        input: Mapping providing ``'sens_maps'``; it is read only in
            ``'k-space'`` mode.

    Returns:
        The network output, contiguous, in the same domain as ``x``.
    """
    kspace_mode = self.space == 'k-space'
    if kspace_mode:
        x = sens_reduce(T.ifft2(x), input['sens_maps'])
    merged = merge_multi_slice(x, cat_dim=-2)
    prediction = self.net(merged.unsqueeze(1).contiguous())
    prediction = unmerge_multi_slice(prediction, 2).contiguous()
    if kspace_mode:
        prediction = T.fft2(sens_expand(prediction, input['sens_maps']))
    return prediction
def test_fft2(shape):
    """Check ``transforms.fft2`` against a centered, orthonormal NumPy 2-D FFT.

    Args:
        shape: Spatial/batch shape of the complex test input. Any sequence
            (list or tuple) is accepted. A trailing axis of size 2 is
            appended to hold the (real, imag) pair.

    The NumPy reference applies ``ifftshift``, then ``fft2(norm='ortho')``,
    then ``fftshift``, all over the last two axes.
    """
    # [*shape, 2] works for tuples too; the old `shape + [2]` only worked for lists.
    data = create_input([*shape, 2])

    out_torch = transforms.fft2(data).numpy()
    # Fold the trailing (real, imag) axis into a complex array for comparison.
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    data_numpy = utils.tensor_to_complex_np(data)
    data_numpy = np.fft.ifftshift(data_numpy, (-2, -1))
    out_numpy = np.fft.fft2(data_numpy, norm='ortho')
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    assert np.allclose(out_torch, out_numpy)
def soft_dc(self, x, input):
    """Compute the lambda-weighted data-consistency residual.

    Inside k-space, the residual ``x - input['kspace']`` is kept where
    ``input['mask']`` is true and replaced by ``self.zero`` elsewhere. The
    result is scaled by ``self.lambda_``.

    If ``self.space == 'img-space'``, ``x`` is first taken to k-space via
    ``sens_expand`` and ``T.fft2``. The masked residual is then returned to
    image space via ``T.ifft2`` and ``sens_reduce``. Otherwise ``x`` is
    compared to ``input['kspace']`` directly (presumably it is already
    k-space; confirm with callers).

    Args:
        x: Current estimate, in the domain selected by ``self.space``.
        input: Mapping providing ``'mask'`` and ``'kspace'``, plus
            ``'sens_maps'`` in image-space mode. ``'mask'`` is the
            ``torch.where`` condition, so it is assumed boolean and
            broadcastable against ``x``. TODO: confirm.

    Returns:
        ``self.lambda_`` times the masked residual.
    """
    if self.space == 'img-space':
        x = T.fft2(sens_expand(x, input['sens_maps']))
    x = torch.where(input['mask'], x - input['kspace'], self.zero)
    if self.space == 'img-space':
        x = sens_reduce(T.ifft2(x), input['sens_maps'])
    return self.lambda_ * x
def net_forward(self, x, input):
    """Run the learned network ``self.net`` on ``x``.

    If ``self.space == 'k-space'``, the input is mapped with ``T.ifft2`` and
    ``sens_reduce`` before the network. The output is mapped back with
    ``sens_expand`` and ``T.fft2``, so the result stays in the input's domain.
    Otherwise ``x`` is passed to the network as-is.

    Args:
        x: Current estimate, in the domain selected by ``self.space``.
        input: Mapping providing ``'sens_maps'``; it is read only in
            ``'k-space'`` mode.

    Returns:
        The network output, contiguous, in the same domain as ``x``.
    """
    if self.space == 'k-space':
        x = sens_reduce(T.ifft2(x), input['sens_maps'])
    # Merge slices along dim -2 and insert a singleton dim at 1
    # (presumably the channel axis expected by self.net; confirm).
    x = merge_multi_slice(x, cat_dim=-2).unsqueeze(1).contiguous()
    x = self.net(x)
    # Split the merged slices back apart. The `2` is interpreted by the
    # helper; TODO: confirm it matches the merge above.
    x = unmerge_multi_slice(x, 2).contiguous()
    if self.space == 'k-space':
        x = T.fft2(sens_expand(x, input['sens_maps']))
    return x
def forward(self, x):
    """Apply ``T.fft2`` to ``x`` and return the result.

    Args:
        x: Input tensor in the layout ``T.fft2`` expects
            (presumably a trailing real/imag axis; confirm against ``T``).

    Returns:
        ``T.fft2(x)``.
    """
    transformed = T.fft2(x)
    return transformed