def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
    """Build visualization grids for complex-image reconstructions and their targets.

    Args:
        c_img_recons: Complex-image reconstructions (real/imag stacked in the layout
            expected by ``fft2`` and ``complex_abs``).
        c_img_targets: Complex-image targets in the same layout.
        smoothing_factor: Passed straight through to ``make_k_grid``.

    Returns:
        Tuple ``(kspace_recons, kspace_targets, image_recons, image_targets, image_deltas)``.
        The two k-space grids come from ``make_k_grid`` applied to ``fft2`` of each input.
        The three image grids come from ``make_grid_triplet`` applied to the magnitude images.
    """
    magnitude_recons = complex_abs(c_img_recons)
    magnitude_targets = complex_abs(c_img_targets)

    # K-space grids are built from the Fourier transform of the complex images.
    k_grid_recons = make_k_grid(fft2(c_img_recons), smoothing_factor)
    k_grid_targets = make_k_grid(fft2(c_img_targets), smoothing_factor)

    grid_recons, grid_targets, grid_deltas = make_grid_triplet(magnitude_recons, magnitude_targets)
    return k_grid_recons, k_grid_targets, grid_recons, grid_targets, grid_deltas
def forward(self, cmg_output, targets, extra_params):
    """Convert a complex-image network output into reconstructions in several domains.

    Args:
        cmg_output: Network output; converted with ``nchw_to_kspace``. Only batch size 1
            is supported.
        targets: Dict that must contain ``'cmg_targets'`` with the same shape as the
            converted output.
        extra_params: Dict; ``'cmg_scales'`` is read for the multi-coil RSS image.

    Returns:
        Dict with ``'kspace_recons'``, ``'cmg_recons'`` and ``'img_recons'``. It also has
        ``'rss_recons'`` when ``self.challenge == 'multicoil'``. Only ``rss_recons`` is
        rescaled by ``cmg_scales``.

    Raises:
        NotImplementedError: batch size > 1, or ``self.residual_acs`` is set.
    """
    if cmg_output.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    cmg_recon = nchw_to_kspace(cmg_output)
    assert cmg_recon.shape == targets['cmg_targets'].shape, 'Reconstruction and target sizes are different.'
    even_height = cmg_recon.size(-3) % 2 == 0
    even_width = cmg_recon.size(-2) % 2 == 0
    assert even_height and even_width, \
        'Not impossible but not expected to have sides with odd lengths.'

    if self.residual_acs:
        # Planned: add the semi-k-space of the ACS as a residual (needed because of complex cropping).
        raise NotImplementedError('Not ready yet.')

    recons = {
        'kspace_recons': fft2(cmg_recon),
        'cmg_recons': cmg_recon,
        'img_recons': complex_abs(cmg_recon),
    }

    if self.challenge == 'multicoil':
        cropped = center_crop(recons['img_recons'], (self.resolution, self.resolution))
        rescaled = cropped * extra_params['cmg_scales']
        recons['rss_recons'] = root_sum_of_squares(rescaled, dim=1).squeeze()

    return recons  # recons are not rescaled except rss_recons.
def forward(self, outputs, targets, extra_params):
    """Combine predicted magnitude and phase into complex-image and k-space reconstructions.

    Args:
        outputs: Pair ``(img_recon, phase_recon)``. Only batch size 1 is supported.
        targets: Dict that must contain ``'img_targets'`` with the same shape as ``img_recon``.
        extra_params: Dict; ``'img_scales'`` is read for the multi-coil RSS image.

    Returns:
        Dict with ``'img_recons'``, ``'phase_recons'``, ``'cmg_recons'`` and
        ``'kspace_recons'``. It also has ``'rss_recons'`` when
        ``self.challenge == 'multicoil'``. Only ``rss_recons`` is rescaled.

    Raises:
        NotImplementedError: batch size > 1.
    """
    magnitude, phase = outputs
    if magnitude.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    # Magnitudes must be non-negative.
    magnitude = F.relu(magnitude)
    assert magnitude.shape == targets['img_targets'].shape, 'Reconstruction and target sizes are different.'

    # The input transform shifted phases by +pi; undo it. No clamping since the loss is MSE.
    phase = phase - math.pi

    real_part = magnitude * torch.cos(phase)
    imag_part = magnitude * torch.sin(phase)
    complex_recon = torch.stack([real_part, imag_part], dim=-1)

    recons = {
        'img_recons': magnitude,
        'phase_recons': phase,
        'cmg_recons': complex_recon,
        'kspace_recons': fft2(complex_recon),
    }

    if self.challenge == 'multicoil':
        cropped = center_crop(magnitude, shape=(self.resolution, self.resolution))
        rescaled = cropped * extra_params['img_scales']
        recons['rss_recons'] = root_sum_of_squares(rescaled, dim=1).squeeze()

    return recons  # recons are not rescaled except rss_recons.
def test_fft2(shape):
    """Check ``data_transforms.fft2`` against NumPy's centred, orthonormal 2-D FFT.

    ``shape`` is a list of leading dimensions; a trailing dimension of 2 is appended
    to hold the real and imaginary parts.
    """
    tensor = create_tensor(shape + [2])

    torch_result = data_transforms.fft2(tensor).numpy()
    actual = torch_result[..., 0] + 1j * torch_result[..., 1]

    # Reference: shift, orthonormal FFT over the last two axes, shift back.
    reference_input = np.fft.ifftshift(data_transforms.tensor_to_complex_np(tensor), (-2, -1))
    expected = np.fft.fftshift(np.fft.fft2(reference_input, norm='ortho'), (-2, -1))

    assert np.allclose(actual, expected)
def _cmg_output(cmg_output, targets, extra_params):
    """Turn a network output into complex-image reconstructions with a residual input.

    Args:
        cmg_output: Network output; converted with ``nchw_to_kspace``. Assumes the data
            was cropped already.
        targets: Dict; ``'cmg_targets'`` fixes the expected shape and ``'cmg_inputs'`` is
            added to the output as a residual.
        extra_params: Unused.

    Returns:
        Dict with ``'kspace_recons'``, ``'cmg_recons'`` and ``'img_recons'``.
    """
    recon = nchw_to_kspace(cmg_output)  # Assumes data was cropped already.
    assert recon.shape == targets['cmg_targets'].shape, 'Reconstruction and target sizes are different.'
    assert (recon.size(-3) % 2 == 0) and (recon.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    # The network predicts a residual on top of the complex input image.
    recon = recon + targets['cmg_inputs']

    return {
        'kspace_recons': fft2(recon),
        'cmg_recons': recon,
        'img_recons': complex_abs(recon),
    }
def forward(self, cmg_output: Tensor, targets: dict, extra_params: dict):
    """Convert a 5-D network output (channel dim of size 2) into reconstructions.

    The output is permuted so that dim 1 moves to the last position. When its shape
    differs from ``targets['kspace_targets']``, it is center-cropped along dim -2 to
    match. If ``self.replace_kspace`` is set, k-space values where the mask is set are
    replaced by the target k-space (``extra_params['masks']``).

    Returns:
        Dict containing only ``'rss_recons'`` when ``self.challenge == 'multicoil'``;
        otherwise an empty dict.

    Raises:
        AssertionError: invalid input shape or mismatched sizes.
        NotImplementedError: batch size > 1.
    """
    assert cmg_output.dim() == 5 and cmg_output.size(1) == 2, 'Invalid shape!'
    if cmg_output.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    kspace_target = targets['kspace_targets']
    target_width = kspace_target.size(-2)
    recon = cmg_output.permute(dims=(0, 2, 3, 4, 1))  # Move the size-2 channel dim to the end.

    if recon.shape != kspace_target.shape:
        # Center-crop the reconstruction left-right to the target width.
        offset = (recon.size(-2) - target_width) // 2
        recon = recon[..., offset:offset + target_width, :]

    assert recon.shape == kspace_target.shape, 'Reconstruction and target sizes are different.'
    assert (recon.size(-3) % 2 == 0) and (recon.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    kspace_recon = fft2(recon)
    if self.replace_kspace:
        mask = extra_params['masks']
        kspace_recon = kspace_target * mask + (1 - mask) * kspace_recon
        recon = ifft2(kspace_recon)

    magnitude = complex_abs(recon)

    # NOTE(review): the full recons dict ('kspace_recons', 'cmg_recons', 'img_recons') was
    # commented out, so non-multicoil callers receive an empty dict. Confirm this is intended.
    recons = {}
    if self.challenge == 'multicoil':
        cropped = center_crop(magnitude, (self.resolution, self.resolution))
        rescaled = cropped * extra_params['cmg_scales']
        recons['rss_recons'] = root_sum_of_squares(rescaled, dim=1).squeeze()

    return recons  # recons are not rescaled except rss_recons.
import torch
import numpy as np

from data.data_transforms import ifft2, fft2

# Round-trip check: scale up, apply ifft2 then fft2, scale back down, and compare
# element-wise with the original data.
# Use the second GPU when it exists; otherwise fall back to the CPU instead of crashing.
device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cpu'
a = torch.rand(20, 40, 60, 92, 2, device=device)
b = ifft2(a * 1E8)
c = fft2(b) * 1E-8

# Take the epsilon from the tensor's own dtype (float32 here). The float64 epsilon is
# far below float32 resolution, so it only guarded the exact-zero case.
eps = torch.finfo(a.dtype).eps
ratio = c / (a + eps)
print(torch.max(ratio))
print(torch.min(ratio))
print(torch.mean(ratio).cpu().numpy())
def check_invertible():
    """Print whether ``fft2`` followed by ``ifft2`` recovers a random float64 tensor.

    A factor of 100 is applied after ``fft2`` and removed after ``ifft2``.
    """
    factor = 100
    source = torch.rand(4, 6, 8, 12, 2, dtype=torch.float64) * 1024 - 64
    recovered = ifft2(fft2(source) * factor) / factor
    print(torch.allclose(source, recovered))
import torch
from data.data_transforms import ifft2, fft2, complex_abs

# Compare the summed k-space magnitude of a random complex image against that of its
# left-right, up-down and combined flips. Prints one allclose result per flip.
image = torch.rand(10, 20, 30, 2)
reference_total = torch.sum(complex_abs(fft2(image)))

matches = []
for flip_dims in ([-2], [-3], [-3, -2]):  # left-right, up-down, both
    flipped_total = torch.sum(complex_abs(fft2(torch.flip(image, dims=flip_dims))))
    matches.append(torch.allclose(reference_total, flipped_total))

print(*matches)
def __call__(self, kspace_target, target, attrs, file_name, slice_num):
    """Build semi-k-space network inputs, targets and extra parameters for one slice.

    Args:
        kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. 3- and 4-dim
            inputs get leading singleton dims added via ``expand``. Batch size must be 1.
        target: Passed through as ``targets['rss_targets']`` for the multi-coil challenge.
        attrs: Dict merged into ``extra_params``.
        file_name: Converted to a tuple of character codes and used as the mask seed
            when ``self.use_seed`` is set.
        slice_num: Unused.

    Returns:
        ``(inputs, targets, extra_params)``.
        - ``inputs`` is the weighted, ``sk_scale``-normalized semi-k-space in NCHW layout.
        - ``targets`` holds the semi-k-space, k-space, complex-image and magnitude
          targets, the magnitude of the undersampled image, and the semi-k-space of the
          ACS region.
        - ``extra_params`` holds ``sk_scales``, ``masks``, ``weightings`` plus the mask
          info and ``attrs``.

    Raises:
        RuntimeError: k-space has an unsupported number of dimensions.
        NotImplementedError: batch size is not 1.
    """
    assert isinstance(kspace_target, torch.Tensor), 'k-space target was expected to be a Pytorch Tensor.'
    if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
        kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
    elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
        kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
    elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
        raise RuntimeError('k-space target has invalid shape!')

    if kspace_target.size(0) != 1:
        raise NotImplementedError('Batch size should be 1 for now.')

    with torch.no_grad():
        seed = None if not self.use_seed else tuple(map(ord, file_name))
        masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)

        complex_image = ifft2(masked_kspace)
        complex_image = complex_center_crop(complex_image, shape=(self.resolution, self.resolution))
        # img_input is not actually an input but what the input would look like in the image domain.
        img_input = complex_abs(complex_image)

        # Direction is fixed due to challenge conditions.
        semi_kspace = fft1(complex_image, direction='width')
        weighting = self.weight_func(semi_kspace)
        semi_kspace *= weighting
        sk_scale = torch.std(semi_kspace)
        semi_kspace /= sk_scale
        inputs = kspace_to_nchw(semi_kspace)

        extra_params = {'sk_scales': sk_scale, 'masks': mask, 'weightings': weighting}
        extra_params.update(info)
        extra_params.update(attrs)

        # Semi-k-space of the ACS region, intended as a residual for the reconstruction.
        # Divide by sk_scale so it shares units with the scaled targets below; it was
        # previously left unscaled.
        num_low_freqs = info['num_low_frequency']
        acs_mask = self.find_acs_mask(kspace_target, num_low_freqs)
        acs_kspace = kspace_target * acs_mask
        semi_kspace_acs = fft1(
            complex_center_crop(ifft2(acs_kspace), shape=(self.resolution, self.resolution)),
            direction='width')
        semi_kspace_acs = semi_kspace_acs / sk_scale

        # Recall that the Fourier transform is a linear transform.
        cmg_target = ifft2(kspace_target)
        cmg_target = complex_center_crop(cmg_target, shape=(self.resolution, self.resolution))
        cmg_target /= sk_scale
        img_target = complex_abs(cmg_target)
        semi_kspace_target = fft1(cmg_target, direction='width')
        kspace_target = fft2(cmg_target)

        # Use plurals as keys to reduce confusion.
        targets = {
            'semi_kspace_targets': semi_kspace_target,
            'kspace_targets': kspace_target,
            'cmg_targets': cmg_target,
            'img_targets': img_target,
            'img_inputs': img_input,
            'semi_kspace_acss': semi_kspace_acs,
        }

        if self.challenge == 'multicoil':
            targets['rss_targets'] = target

    return inputs, targets, extra_params
def __call__(self, kspace_target, target, attrs, file_name, slice_num):
    """Build complex-image CNN inputs, targets and extra parameters for one slice.

    Args:
        kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. 3- and 4-dim
            inputs get leading singleton dims added via ``expand``. Batch size must be 1.
            The tensor is not modified.
        target: Flipped together with the images when augmenting; passed through as
            ``targets['rss_targets']`` for the multi-coil challenge.
        attrs: Dict merged into ``extra_params``.
        file_name: Converted to a tuple of character codes and used as the mask seed
            when ``self.use_seed`` is set.
        slice_num: Unused.

    Returns:
        ``(inputs, targets, extra_params)``.
        - ``inputs`` is the concatenation of the undersampled complex image and its
          magnitude along the last dim, converted with ``kspace_to_nchw``.
        - ``targets['kspace_targets']`` always corresponds to ``targets['cmg_targets']``
          after any cropping or flipping.

    Raises:
        RuntimeError: k-space has an unsupported number of dimensions.
        NotImplementedError: batch size is not 1.
    """
    assert isinstance(kspace_target, torch.Tensor), 'k-space target was expected to be a Pytorch Tensor.'
    if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
        kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
    elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
        kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
    elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
        raise RuntimeError('k-space target has invalid shape!')

    if kspace_target.size(0) != 1:
        raise NotImplementedError('Batch size should be 1 for now.')

    with torch.no_grad():
        # Apply mask.
        seed = None if not self.use_seed else tuple(map(ord, file_name))
        masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)

        # Complex image made from down-sampled k-space.
        complex_image = ifft2(masked_kspace)
        if self.crop_center:
            complex_image = complex_center_crop(complex_image, shape=(self.resolution, self.resolution))

        cmg_scale = torch.std(complex_image)
        complex_image /= cmg_scale

        extra_params = {'cmg_scales': cmg_scale, 'masks': mask}
        extra_params.update(info)
        extra_params.update(attrs)

        # Recall that the Fourier transform is a linear transform.
        # Out-of-place division: an in-place `/=` would write through to the caller's
        # tensor, since `expand` returns a view and the 5-dim case is the argument itself.
        kspace_target = kspace_target / cmg_scale
        cmg_target = ifft2(kspace_target)
        if self.crop_center:
            cmg_target = complex_center_crop(cmg_target, shape=(self.resolution, self.resolution))

        # Data augmentation by flipping images up-down and left-right. No rotation implemented.
        flipped = False
        if self.augment_data:
            flip_lr = torch.rand(()) < 0.5
            flip_ud = torch.rand(()) < 0.5
            # The last dim of complex tensors is real/imag, so height/width are dims -3/-2.
            # The plain target has only height and width, so they are dims -2/-1.
            cmg_dims, target_dims = [], []
            if flip_ud:
                cmg_dims.append(-3)
                target_dims.append(-2)
            if flip_lr:
                cmg_dims.append(-2)
                target_dims.append(-1)
            if cmg_dims:
                complex_image = torch.flip(complex_image, dims=cmg_dims)
                cmg_target = torch.flip(cmg_target, dims=cmg_dims)
                target = torch.flip(target, dims=target_dims)
                flipped = True

        # Keep the k-space target consistent with the (cropped and/or flipped) complex
        # target. Previously, with crop_center on, this was recomputed only when a flip
        # happened, so the returned shape depended on the random draw.
        if self.crop_center or flipped:
            kspace_target = fft2(cmg_target)

        # The image target is obtained after flipping the complex image,
        # which removes the need to flip the image target separately.
        img_target = complex_abs(cmg_target)
        img_inputs = complex_abs(complex_image)

        # Use plurals as keys to reduce confusion.
        targets = {
            'kspace_targets': kspace_target,
            'cmg_targets': cmg_target,
            'img_targets': img_target,
            'cmg_inputs': complex_image,
            'img_inputs': img_inputs,
        }

        if self.challenge == 'multicoil':
            targets['rss_targets'] = target

        # Creating concatenated image of real/imag/abs channels.
        concat_image = torch.cat([complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)

        # Converting to NCHW format for CNN.
        inputs = kspace_to_nchw(concat_image)

    return inputs, targets, extra_params
def __call__(self, kspace_target, target, attrs, file_name, slice_num):
    """Build magnitude/phase network inputs, targets and extra parameters for one slice.

    Args:
        kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. 3- and 4-dim
            inputs get leading singleton dims added via ``expand``. Batch size must be 1.
        target: Flipped together with the images when augmenting; passed through as
            ``targets['rss_targets']`` for the multi-coil challenge.
        attrs: Dict merged into ``extra_params``.
        file_name: Converted to a tuple of character codes and used as the mask seed
            when ``self.use_seed`` is set.
        slice_num: Unused.

    Returns:
        ``(inputs, targets, extra_params)``.
        - ``inputs`` is ``(img_input, phase_input)``: the magnitude divided by its std,
          and the phase shifted by +pi. The output transform must remove the shift.
        - The complex target is divided by the same std, and ``kspace_targets`` is
          recomputed from it after cropping and flipping.

    Raises:
        RuntimeError: k-space has an unsupported number of dimensions.
        NotImplementedError: batch size is not 1.
    """
    assert isinstance(kspace_target, torch.Tensor), 'k-space target was expected to be a Pytorch Tensor.'
    ndim = kspace_target.dim()
    if ndim == 3:  # Collate function did not expand dimensions for single-coil.
        kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
    elif ndim == 4:  # Collate function did not expand dimensions for multi-coil.
        kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
    elif ndim != 5:  # Expanded k-space should have 5 dimensions.
        raise RuntimeError('k-space target has invalid shape!')

    if kspace_target.size(0) != 1:
        raise NotImplementedError('Batch size should be 1 for now.')

    with torch.no_grad():
        seed = tuple(map(ord, file_name)) if self.use_seed else None
        masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)

        complex_image = ifft2(masked_kspace)
        cmg_target = ifft2(kspace_target)
        if self.crop_center:
            crop_shape = (self.resolution, self.resolution)
            complex_image = complex_center_crop(complex_image, shape=crop_shape)
            cmg_target = complex_center_crop(cmg_target, shape=crop_shape)

        if self.augment_data:
            flip_lr = torch.rand(()) < 0.5
            flip_ud = torch.rand(()) < 0.5
            # Complex tensors keep real/imag in the last dim, so height/width are -3/-2.
            # The plain target has only height and width, so they are -2/-1.
            complex_dims, plain_dims = [], []
            if flip_ud:
                complex_dims.append(-3)
                plain_dims.append(-2)
            if flip_lr:
                complex_dims.append(-2)
                plain_dims.append(-1)
            if complex_dims:
                complex_image = torch.flip(complex_image, dims=complex_dims)
                cmg_target = torch.flip(cmg_target, dims=complex_dims)
                target = torch.flip(target, dims=plain_dims)

        # Shift the phase by pi so it lies in [0, 2pi] for better learning.
        # Don't forget to remove the pi in the output transform!
        phase_input = torch.atan2(complex_image[..., 1], complex_image[..., 0])
        phase_input += math.pi

        img_input = complex_abs(complex_image)
        img_scale = torch.std(img_input)
        img_input /= img_scale
        cmg_target /= img_scale

        img_target = complex_abs(cmg_target)
        # Rebuild the k-space target after cropping and augmentation.
        kspace_target = fft2(cmg_target)
        phase_target = torch.atan2(cmg_target[..., 1], cmg_target[..., 0])

        extra_params = {'img_scales': img_scale, 'masks': mask}
        extra_params.update(info)
        extra_params.update(attrs)

        # Use plurals as keys to reduce confusion.
        targets = {
            'kspace_targets': kspace_target,
            'cmg_targets': cmg_target,
            'img_targets': img_target,
            'phase_targets': phase_target,
            'img_inputs': img_input,
        }
        if self.challenge == 'multicoil':
            targets['rss_targets'] = target

        inputs = (img_input, phase_input)

    return inputs, targets, extra_params