def forward(self, kspace_outputs, targets, extra_params):
    """Turn a (possibly width-padded) NCHW k-space network output into reconstructions.

    Args:
        kspace_outputs: Network output in NCHW form. Only a batch size of 1 is supported.
        targets (dict): Must contain 'kspace_targets'; the cropped output must match its shape.
        extra_params (dict): 'weightings' is read when self.weighted is set and
            'masks' is read when self.replace is set.

    Returns:
        dict: 'kspace_recons', 'cmg_recons' (ifft2 of the k-space) and 'img_recons'
            (complex_abs of the complex image). Values remain scaled; nothing is rescaled here.

    Raises:
        NotImplementedError: If more than one slice is passed.
    """
    batch_size = kspace_outputs.size(0)
    if batch_size > 1:
        raise NotImplementedError('Only one slice at a time for now.')

    kspace_targets = targets['kspace_targets']

    # k-space form has 2 as its last dim size, so the true width of the target is dim -2.
    true_width = kspace_targets.size(-2)
    start = (kspace_outputs.size(-1) - true_width) // 2
    # An explicit end index (rather than [pad:-pad]) stays correct when there is no padding.
    kspace_recons = nchw_to_kspace(kspace_outputs[..., start:start + true_width])

    assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
    assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    if self.weighted:
        # Undo the k-space weighting.
        kspace_recons = kspace_recons / extra_params['weightings']

    if self.replace:
        # Take values from the target where the mask is set and from the network elsewhere
        # (assumes a 0/1 sampling mask).
        sampling_mask = extra_params['masks']
        kspace_recons = kspace_recons * (1 - sampling_mask) + kspace_targets * sampling_mask

    cmg_recons = ifft2(kspace_recons)
    return {
        'kspace_recons': kspace_recons,
        'cmg_recons': cmg_recons,
        'img_recons': complex_abs(cmg_recons),
    }
def forward(self, cmg_output, targets, extra_params):
    """Convert an NCHW complex-image network output into complex-image, k-space and magnitude recons.

    Args:
        cmg_output: Network output in NCHW form, expected to already match the target size
            (no width cropping is done). Only a batch size of 1 is supported.
        targets (dict): Must contain 'cmg_targets'.
        extra_params (dict): 'cmg_scales' is read for the multicoil challenge.

    Returns:
        dict: 'kspace_recons' (fft2 of the complex image), 'cmg_recons', 'img_recons' and,
            for the multicoil challenge, 'rss_recons'. Only 'rss_recons' is rescaled.

    Raises:
        NotImplementedError: For batch sizes above 1, or when self.residual_acs is set.
    """
    if cmg_output.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    cmg_target = targets['cmg_targets']
    cmg_recon = nchw_to_kspace(cmg_output)

    assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
    assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    if self.residual_acs:
        # TODO: add the ACS complex image as a residual (targets['cmg_acss']); not implemented yet.
        raise NotImplementedError('Not ready yet.')

    recons = dict(
        kspace_recons=fft2(cmg_recon),
        cmg_recons=cmg_recon,
        img_recons=complex_abs(cmg_recon),
    )

    if self.challenge == 'multicoil':
        crop_size = (self.resolution, self.resolution)
        # Multiply by the stored scale before combining coils, so only the RSS image is rescaled.
        scaled_crop = center_crop(recons['img_recons'], crop_size) * extra_params['cmg_scales']
        recons['rss_recons'] = root_sum_of_squares(scaled_crop, dim=1).squeeze()

    return recons
def forward(self, semi_kspace_outputs, targets, extra_params):
    """Turn a (possibly width-padded) NCHW semi-k-space network output into reconstructions.

    Semi-k-space here is data from which full k-space is obtained by fft1 along height
    and the complex image by ifft1 along width.

    Args:
        semi_kspace_outputs: Network output in NCHW form. Only a batch size of 1 is supported.
        targets (dict): Must contain 'semi_kspace_targets'.
        extra_params (dict): Reads 'weightings' (if self.weighted), 'num_low_frequency'
            (if self.residual_acs), 'masks' (if self.replace) and 'sk_scales' (multicoil).

    Returns:
        dict: 'semi_kspace_recons', 'kspace_recons', 'cmg_recons', 'img_recons' and, for the
            multicoil challenge, 'rss_recons'. Only 'rss_recons' is rescaled.

    Raises:
        NotImplementedError: If more than one slice is passed.
    """
    if semi_kspace_outputs.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    semi_kspace_targets = targets['semi_kspace_targets']

    # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
    left = (semi_kspace_outputs.size(-1) - semi_kspace_targets.size(-2)) // 2
    right = left + semi_kspace_targets.size(-2)

    # Cropping width dimension by pad.
    semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

    assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
    assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    # Removing weighting.
    if self.weighted:
        weighting = extra_params['weightings']
        semi_kspace_recons = semi_kspace_recons / weighting

    if self.residual_acs:
        # Add the target values selected by find_acs_mask (built from num_low_frequency)
        # back onto the reconstruction as a residual.
        num_low_freqs = extra_params['num_low_frequency']
        acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
        semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

    if self.replace:
        # Data consistency: target values where the mask is set, network values elsewhere.
        mask = extra_params['masks']
        semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

    kspace_recons = fft1(semi_kspace_recons, direction='height')
    cmg_recons = ifft1(semi_kspace_recons, direction='width')
    img_recons = complex_abs(cmg_recons)

    recons = {
        'semi_kspace_recons': semi_kspace_recons,
        'kspace_recons': kspace_recons,
        'cmg_recons': cmg_recons,
        'img_recons': img_recons
    }

    if self.challenge == 'multicoil':
        rss_recons = center_crop(img_recons, (self.resolution, self.resolution))
        rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
        # The inputs were divided by 'sk_scales', so multiply it back here.
        # Done out of place on purpose: `squeeze()` returns a view of the RSS result, and an
        # in-place `*=` on it bumps the version of a tensor autograd may have saved (e.g. the
        # output of a sqrt), which makes loss.backward() raise.
        rss_recons = rss_recons * extra_params['sk_scales']
        recons['rss_recons'] = rss_recons

    return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
def forward(self, kspace_output, kspace_target, extra_params):
    """Map a log-weighted NCHW k-space output back to k-space and enforce data consistency.

    Args:
        kspace_output: Network output in NCHW form, possibly padded along width.
            Relies on a batch size of 1.
        kspace_target: Target k-space; its dim -2 is the true width.
        extra_params: Pair ``(k_scale, mask)``.

    Returns:
        The k-space reconstruction, with values taken from kspace_target where mask is set.
    """
    k_scale, mask = extra_params

    target_width = kspace_target.size(-2)
    offset = (kspace_output.size(-1) - target_width) // 2
    cropped = kspace_output[..., offset:offset + target_width]

    # Reshape to k-space form, undo the weighting via exp_weighting, then restore the original scale.
    kspace_recon = exp_weighting(nchw_to_kspace(cropped), scale=self.log_amp_scale) * k_scale

    assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'
    return kspace_recon * (1 - mask) + kspace_target * mask
def forward(self, tensor, out_shape):  # Using out_shape only works for batch size of 1.
    """
    Run the encoder/decoder on k-space input and return magnitude images of the true shape.

    Args:
        tensor (torch.Tensor): Input tensor of shape [batch_size, in_chans, height, width]
        out_shape (tuple): shape [batch_size, num_coils, true_height, true_width].
            Note that in_chans = 2 * num_coils
    Returns:
        (torch.Tensor): Tensor of shape out_shape. It is complex_abs(ifft2(.)) of the network
            output after cropping the width to true_width and converting to k-space form.
    """
    skips = []
    features = tensor

    # Encoder: keep each level's features for the skip connections, then halve the resolution.
    for down_layer in self.down_sample_layers:
        features = down_layer(features)
        skips.append(features)
        features = F.max_pool2d(features, kernel_size=2)

    features = self.conv(features)

    # Decoder: upsample, concatenate the most recently stored skip features, then convolve.
    for up_layer in self.up_sample_layers:
        upsampled = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=False)
        features = up_layer(torch.cat((upsampled, skips.pop()), dim=1))

    features = self.conv2(features)  # End of learning.

    # Crop away width padding (this depends on mini-batch size being 1 to work).
    # An explicit end index is used because [pad:-pad] yields an empty tensor when pad == 0.
    true_width = out_shape[-1]
    start = (features.size(-1) - true_width) // 2
    features = features[..., start:start + true_width]

    image = complex_abs(ifft2(nchw_to_kspace(features)))
    assert image.size() == out_shape  # Checking just in case.
    return image
def forward(self, kspace_output, c_img_target, extra_params):
    """Crop, unweight and rescale a k-space output, enforce data consistency, return the complex image.

    Args:
        kspace_output: Network output in NCHW form, possibly padded along width.
        c_img_target: Not used by this method.
        extra_params: Triple ``(kspace_target, k_scale, mask)``.

    Returns:
        ifft2 of the data-consistent k-space reconstruction.
    """
    kspace_target, k_scale, mask = extra_params

    # k-space form has 2 as its last dim size, so the true width is dim -2 of the target.
    width = kspace_target.size(-2)
    lo = (kspace_output.size(-1) - width) // 2
    hi = lo + width

    # Undo the weighting via exp_weighting, then multiply by k_scale to restore the original scaling.
    recon = exp_weighting(nchw_to_kspace(kspace_output[..., lo:hi]), scale=self.log_amp_scale) * k_scale
    assert recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

    recon = recon * (1 - mask) + kspace_target * mask
    return ifft2(recon)
def forward(self, kspace_outputs, targets, extra_params):
    """Turn a (possibly width-padded) NCHW k-space output into reconstructions, plus RSS for 15 channels.

    Args:
        kspace_outputs: Network output in NCHW form. Only a batch size of 1 is supported.
        targets (dict): Must contain 'kspace_targets'.
        extra_params (dict): Reads 'weightings' (if self.weighted), 'num_low_frequency'
            (if self.residual_acs) and 'masks' (if self.replace).

    Returns:
        dict: 'kspace_recons', 'cmg_recons', 'img_recons' and, when the image has 15 channels
            in dim 1, 'rss_recons' (center-cropped to self.resolution and squeezed).
            Nothing is rescaled.

    Raises:
        NotImplementedError: If more than one slice is passed.
    """
    if kspace_outputs.size(0) > 1:
        raise NotImplementedError('Only one slice at a time for now.')

    kspace_targets = targets['kspace_targets']

    # k-space form has 2 as its last dim size, so the true width is dim -2 of the target.
    true_width = kspace_targets.size(-2)
    pad_start = (kspace_outputs.size(-1) - true_width) // 2
    kspace_recons = nchw_to_kspace(kspace_outputs[..., pad_start:pad_start + true_width])

    assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
    assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    if self.weighted:
        kspace_recons = kspace_recons / extra_params['weightings']

    if self.residual_acs:
        # Add back the target values selected by find_acs_mask as a residual.
        acs_mask = find_acs_mask(kspace_recons, extra_params['num_low_frequency'])
        kspace_recons = kspace_recons + acs_mask * kspace_targets

    if self.replace:
        # Replace with the original k-space where the mask is set.
        sampling_mask = extra_params['masks']
        kspace_recons = kspace_recons * (1 - sampling_mask) + kspace_targets * sampling_mask

    cmg_recons = ifft2(kspace_recons)
    img_recons = complex_abs(cmg_recons)
    recons = {
        'kspace_recons': kspace_recons,
        'cmg_recons': cmg_recons,
        'img_recons': img_recons
    }

    # NOTE(review): a channel count of exactly 15 is used as the multicoil signal here,
    # whereas other forward methods in this file check self.challenge; confirm this is intended.
    if img_recons.size(1) == 15:
        res = self.resolution
        top = (img_recons.size(-2) - res) // 2
        side = (img_recons.size(-1) - res) // 2
        cropped = img_recons[:, :, top:top + res, side:side + res]
        recons['rss_recons'] = root_sum_of_squares(cropped, dim=1).squeeze()  # rss_recon is in 2D

    return recons  # Returning scaled reconstructions. Not rescaled.
def _cmg_output(cmg_output, targets, extra_params):
    """Build reconstructions from an NCHW complex-image output, adding the complex input as a residual.

    Args:
        cmg_output: Network output in NCHW form; assumed to be cropped already.
        targets (dict): Must contain 'cmg_targets' and 'cmg_inputs'.
        extra_params: Not used by this function.

    Returns:
        dict: 'kspace_recons' (fft2 of the complex image), 'cmg_recons', 'img_recons'.
    """
    cmg_target = targets['cmg_targets']
    recon = nchw_to_kspace(cmg_output)  # Assumes data was cropped already.

    assert recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
    assert (recon.size(-3) % 2 == 0) and (recon.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    # The network predicts a residual on top of the complex input.
    recon = recon + targets['cmg_inputs']

    kspace = fft2(recon)
    magnitude = complex_abs(recon)
    return {'kspace_recons': kspace, 'cmg_recons': recon, 'img_recons': magnitude}
def forward(self, kspace_output, kspace_target, extra_params):
    """Crop width padding from a k-space output, rescale it and convert it to k-space form.

    Args:
        kspace_output: Network output in NCHW form. The batch size must be 1.
        kspace_target: Target k-space; its dim -2 is the true width.
        extra_params: Pair ``(scaling, mask)``; the mask is not used here.

    Returns:
        The rescaled k-space reconstruction, with the same size as kspace_target.

    Raises:
        NotImplementedError: If the batch size is not 1.
    """
    if kspace_output.size(0) != 1:
        raise NotImplementedError('Only single batch for now.')

    scaling, _mask = extra_params

    target_width = kspace_target.size(-2)
    start = (kspace_output.size(-1) - target_width) // 2

    # Multiply by the scale to restore the original scaling, then reshape (needs batch size 1).
    kspace_recon = nchw_to_kspace(kspace_output[..., start:start + target_width] * scaling)

    assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'
    return kspace_recon
def forward(self, kspace_output, c_img_target, extra_params):
    """Crop and rescale a k-space output, enforce data consistency and return the complex image.

    Args:
        kspace_output: Network output in NCHW form, possibly padded along width.
            Relies on a batch size of 1.
        c_img_target: Not used by this method.
        extra_params: Triple ``(kspace_target, scaling, mask)``.

    Returns:
        ifft2 of the data-consistent k-space reconstruction.
    """
    kspace_target, scaling, mask = extra_params

    true_width = kspace_target.size(-2)
    start = (kspace_output.size(-1) - true_width) // 2
    rescaled = kspace_output[..., start:start + true_width] * scaling

    kspace_recon = nchw_to_kspace(rescaled)
    assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

    # Keep the target's values where the mask is set; use the network elsewhere.
    consistent = kspace_recon * (1 - mask) + kspace_target * mask
    return ifft2(consistent)
def forward(self, semi_kspace_outputs, targets, extra_params):
    """Turn a (possibly width-padded) NCHW semi-k-space output into reconstructions.

    k-space is obtained with fft1 along self.direction and the complex image with
    ifft1 along self.recon_direction.

    Args:
        semi_kspace_outputs: Network output in NCHW form. Only a batch size of 1 is supported.
        targets (dict): Must contain 'semi_kspace_targets'.
        extra_params (dict): Reads 'weightings' (if self.weighted) and 'masks' (if self.replace).

    Returns:
        dict: 'semi_kspace_recons', 'kspace_recons', 'cmg_recons', 'img_recons'.
            Nothing is rescaled.

    Raises:
        NotImplementedError: If more than one item is in the batch.
    """
    if semi_kspace_outputs.size(0) > 1:
        raise NotImplementedError('Only one batch at a time for now.')

    sk_targets = targets['semi_kspace_targets']

    # k-space form has 2 as its last dim size, so the true width is dim -2 of the target.
    width = sk_targets.size(-2)
    begin = (semi_kspace_outputs.size(-1) - width) // 2
    sk_recons = nchw_to_kspace(semi_kspace_outputs[..., begin:begin + width])

    assert sk_recons.shape == sk_targets.shape, 'Reconstruction and target sizes are different.'
    assert (sk_recons.size(-3) % 2 == 0) and (sk_recons.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    if self.weighted:
        sk_recons = sk_recons / extra_params['weightings']

    if self.replace:
        sampling_mask = extra_params['masks']
        sk_recons = sk_recons * (1 - sampling_mask) + sk_targets * sampling_mask

    kspace_recons = fft1(sk_recons, direction=self.direction)
    cmg_recons = ifft1(sk_recons, direction=self.recon_direction)

    return {
        'semi_kspace_recons': sk_recons,
        'kspace_recons': kspace_recons,
        'cmg_recons': cmg_recons,
        'img_recons': complex_abs(cmg_recons),
    }
def forward(self, semi_kspace_outputs, targets, extra_params):
    """Turn an unpadded NCHW semi-k-space network output into reconstructions.

    k-space is obtained by fft1 along height and the complex image by ifft1 along width.

    Args:
        semi_kspace_outputs: Network output in NCHW form. It must already match the target
            shape (no width cropping is done). Only a batch size of 1 is supported.
        targets (dict): Must contain 'semi_kspace_targets', and 'semi_kspace_acss'
            when self.residual_acs is set.
        extra_params (dict): Reads 'weightings' (if self.weighted) and 'sk_scales' (multicoil).

    Returns:
        dict: 'semi_kspace_recons', 'kspace_recons', 'cmg_recons', 'img_recons' and, for the
            multicoil challenge, 'rss_recons'. Only 'rss_recons' is rescaled.

    Raises:
        NotImplementedError: If more than one slice is passed.
    """
    if semi_kspace_outputs.size(0) > 1:
        raise NotImplementedError('Only one at a time for now.')

    semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs)
    semi_kspace_targets = targets['semi_kspace_targets']

    assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
    assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
        'Not impossible but not expected to have sides with odd lengths.'

    # Removing weighting.
    if self.weighted:
        weighting = extra_params['weightings']
        semi_kspace_recons = semi_kspace_recons / weighting

    if self.residual_acs:
        # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
        semi_kspace_acs = targets['semi_kspace_acss']
        semi_kspace_recons = semi_kspace_recons + semi_kspace_acs

    kspace_recons = fft1(semi_kspace_recons, direction='height')
    cmg_recons = ifft1(semi_kspace_recons, direction='width')
    img_recons = complex_abs(cmg_recons)

    recons = {
        'semi_kspace_recons': semi_kspace_recons,
        'kspace_recons': kspace_recons,
        'cmg_recons': cmg_recons,
        'img_recons': img_recons
    }

    if self.challenge == 'multicoil':
        rss_recons = root_sum_of_squares(img_recons, dim=1).squeeze()
        # Multiply back the scale the inputs were divided by. Done out of place on purpose:
        # `squeeze()` returns a view of the RSS result, and an in-place `*=` on it bumps the
        # version of a tensor autograd may have saved (e.g. the output of a sqrt), which makes
        # loss.backward() raise.
        rss_recons = rss_recons * extra_params['sk_scales']
        recons['rss_recons'] = rss_recons

    return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
def forward(self, k_output, targets, scales):
    """
    Output post-processing for output k-space tensor with batch size of 1.
    This is experimental and is subject to change.

    Crops width padding from the CNN k-space output, multiplies by scales to restore
    the original scaling, converts to k-space form and computes the magnitude image.

    Args:
        k_output (torch.Tensor): CNN output of k-space. Expected to have batch-size of 1.
        targets (torch.Tensor): Target image domain data. Its last dim gives the true width,
            and its size must equal the size of the returned image.
        scales (torch.Tensor): scaling factor used to divide the input k-space slice data.

    Returns:
        tuple: ``(image_recons, kspace_recons)``. image_recons is complex_abs(ifft2(kspace_recons));
            kspace_recons is the k-space in original shape with the batch dimension in place.
    """
    true_width = targets.size(-1)
    start = (k_output.size(-1) - true_width) // 2  # This depends on mini-batch size being 1 to work.

    # An explicit end index is used because [pad:-pad] yields an empty tensor when pad == 0.
    rescaled = k_output[..., start:start + true_width] * scales

    kspace_recons = nchw_to_kspace(rescaled)
    image_recons = complex_abs(ifft2(kspace_recons))

    assert image_recons.size() == targets.size(), 'Reconstruction and target sizes do not match.'
    return image_recons, kspace_recons
def restore_orig_shape(k_slice, target_slice):
    """Crop width padding from an NCHW k-space slice and convert it to k-space form.

    Any odd leftover column of padding is removed from the right side.

    Args:
        k_slice: NCHW tensor whose last dim may be wider than the target's.
        target_slice: Tensor whose last dim is the true width.

    Returns:
        nchw_to_kspace applied to the cropped slice.
    """
    target_width = target_slice.size(-1)
    left_pad = (k_slice.size(-1) - target_width) // 2
    # Slice with an explicit end index. The previous `[left_pad:-right_pad]` form returned an
    # empty tensor whenever right_pad was 0 (no padding), because `-0 == 0`.
    k_slice = k_slice[..., left_pad:left_pad + target_width]
    return nchw_to_kspace(k_slice)
import numpy as np
import torch
from time import time

# Round-trip check between the k-space layout and the NCHW layout, with the batched
# conversion compared against the per-slice conversion.
# NOTE(review): each random array below holds 32*15*640*328 float64 values (~0.8 GB);
# shrink the shape for a quick smoke test.
shape = (32, 15, 640, 328)
k1 = np.random.uniform(size=shape)
k2 = np.random.uniform(size=shape)
k = k1 + k2 * 1j
kt = to_tensor(k)

# Time only the two conversions; verification and printing are kept outside the timed region.
tic = time()
ncwh = kspace_to_nchw(kt)
kspace = nchw_to_kspace(ncwh)
toc = time() - tic

# The per-slice conversion must agree with the batched conversion.
for idx, kts in enumerate(kt):
    temp = k_slice_to_chw(kts)
    print(idx, torch.eq(ncwh[idx], temp).all())

# NCHW channel `chan` is expected to hold coil chan // 2 at last-dim index chan % 2
# (presumably 0 = real, 1 = imaginary; depends on to_tensor).
chan = 17
ri = chan % 2
sli = chan // 2
print(torch.eq(torch.squeeze(ncwh[3, chan, ...]), torch.squeeze(kt[3, sli, ..., ri])).all())

print(torch.eq(kt, kspace).all(), toc)