def generate(generator, data, device):
    """Run the generator and return both projected and raw magnitude images.

    Args:
        generator: Network applied to the channels-first permutation of the
            batch input (the input is permuted with ``(0, 3, 1, 2)`` and the
            result back with ``(0, 2, 3, 1)``).
        data: 8-tuple batch; items 0 (network input), 2 (mean), 3 (std) and
            4 (mask) are used.
        device: Device the tensors are moved to.

    Returns:
        tuple: (magnitude of the k-space-consistent output, magnitude of the
        raw network output, target_kspace, output_kspace). Both images are
        un-normalized, centre-cropped to 320x320 and converted with
        ``transforms.complex_abs``; the two k-space tensors are returned as
        produced by ``project_to_consistent_subspace``.
    """
    zero_filled, _, mean, std, mask, _, _, _ = data
    zero_filled = zero_filled.to(device)
    mask = mask.to(device)
    prediction = generator(zero_filled.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    consistent, target_kspace, output_kspace = project_to_consistent_subspace(
        prediction, zero_filled, mask)

    # Add three trailing singleton axes so the per-sample statistics broadcast.
    mean = mean.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
    std = std.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)

    def to_magnitude(tensor):
        # Loss is taken on the cropped, real-valued (abs) image.
        tensor = transforms.unnormalize(tensor, mean, std)
        tensor = transforms.complex_center_crop(tensor, (320, 320))
        return transforms.complex_abs(tensor)

    consistent_magnitude = to_magnitude(consistent)
    network_magnitude = to_magnitude(prediction)
    return consistent_magnitude, network_magnitude, target_kspace, output_kspace
def train_step(model, data, device):
    """Return the summed k-space and image-domain L1 loss for one batch.

    The prediction is made data-consistent by keeping the input where
    ``mask`` is 1 and the network output elsewhere. Two L1 terms are added:

    * k-space term: consistent prediction vs. ``target``, both still in the
      normalized k-space domain;
    * image term: both k-spaces are un-normalized with ``mean``/``std``,
      inverse transformed, centre-cropped to 320x320, converted to magnitude
      and normalized with ``mean_image``/``std_image``. Only the target image
      is clamped to [-6, 6].

    Args:
        model: Network applied to the channels-first permutation of the input.
        data: 7-tuple ``(input, target, mean, std, mean_image, std_image, mask)``.
        device: Device the tensors are moved to.

    Returns:
        The scalar loss tensor.
    """
    measured, target, mean, std, mean_image, std_image, mask = data
    measured = measured.to(device)
    mask = mask.to(device)
    target = target.to(device)

    predicted = model(measured.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # Projection to consistent k-space: keep sampled entries from the input.
    predicted = measured * mask + (1 - mask) * predicted
    kspace_loss = F.l1_loss(predicted, target)

    mean, std = mean.to(device), std.to(device)
    mean_image = mean_image.unsqueeze(1).unsqueeze(2).to(device)
    std_image = std_image.unsqueeze(1).unsqueeze(2).to(device)

    def normalized_magnitude(kspace):
        image = transforms.ifft2(transforms.unnormalize(kspace, mean, std))
        image = transforms.complex_abs(
            transforms.complex_center_crop(image, (320, 320)))
        return transforms.normalize(image, mean_image, std_image)

    output_image = normalized_magnitude(predicted)
    target_image = normalized_magnitude(target).clamp(-6, 6)
    # Image loss on the normalized magnitude images.
    image_loss = F.l1_loss(output_image, target_image)
    return kspace_loss + image_loss
def __call__(self, kspace, target, attrs, fname, slice):
    """Build one training sample from a k-space slice.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image; returned as a tensor in
            ``target_inference``.
        attrs (dict): Acquisition related information stored in the HDF5
            object; ``'max'`` and ``'norm'`` are read.
        fname (str): File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice (int): Serial number of the slice (not used here).

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled complex image normalized with the
                complex statistics of its centre crop (clamped if ``CLAMP``).
            target_train (torch.Tensor): 320x320 magnitude of the fully
                sampled image (renormalized if ``RENORM``, clamped if ``CLAMP``).
            target_kspace_train (torch.Tensor): k-space of the normalized fully
                sampled complex image.
            mean, std: Complex normalization statistics.
            mask (torch.Tensor): Sampling mask; all ones when
                ``self.use_mask`` is False.
            mean_abs, std_abs: Statistics of the cropped zero-filled magnitude.
            target_inference (torch.Tensor): ``target`` converted to a tensor.
            max: ``attrs['max']``.
            norm (numpy.float32): ``attrs['norm']``.
    """
    target_inference = transforms.to_tensor(target)
    kspace = transforms.to_tensor(kspace)
    # Fully sampled complex image; the training targets are derived from it.
    target = transforms.ifft2(kspace)

    # Apply mask
    seed = None if not self.use_seed else tuple(map(ord, fname))
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        # No undersampling means an all-ones mask. Defining it here keeps the
        # returned tuple valid (``mask`` used to be unbound on this path).
        mask = kspace.new_ones(kspace.shape)
        masked_kspace = kspace

    image = transforms.ifft2(masked_kspace)
    image_crop = transforms.complex_center_crop(
        image, (self.resolution, self.resolution))
    # Complex statistics of the crop normalize both the input and the target.
    _, mean, std = transforms.normalize_instance_complex(image_crop, eps=1e-11)
    # Magnitude statistics are only used by the optional RENORM step.
    image_abs = transforms.complex_abs(image_crop)
    _, mean_abs, std_abs = transforms.normalize_instance(image_abs, eps=1e-11)

    image = transforms.normalize(image, mean, std)
    target_image_complex_norm = transforms.normalize(target, mean, std)
    target_kspace_train = transforms.fft2(target_image_complex_norm)

    target = transforms.complex_center_crop(target, (320, 320))
    target = transforms.complex_abs(target)
    target_train = target
    if RENORM:
        target_train = transforms.normalize(target_train, mean_abs, std_abs)
    if CLAMP:
        image = image.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)
    return (image, target_train, target_kspace_train, mean, std, mask,
            mean_abs, std_abs, target_inference, attrs['max'],
            attrs['norm'].astype(np.float32))
def __call__(self, kspace, target, attrs, fname, slice):
    """Turn a k-space slice into a normalized zero-filled magnitude image.

    Args:
        kspace (numpy.Array): k-space measurements.
        target (numpy.Array): Target image (not used here).
        attrs (dict): Acquisition related information (not used here).
        fname: File name; when ``self.mask_func`` is set, the mask seed is
            ``tuple(map(ord, fname))``, so it must iterate to characters.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): (image, mean, std, fname, slice) where ``image`` is the
        centre-cropped magnitude (root-sum-of-squares for multicoil),
        normalized with ``transforms.normalize_instance`` and clamped to
        [-6, 6]; ``mean``/``std`` are the statistics used.
    """
    kspace = transforms.to_tensor(kspace)
    undersampled = kspace
    if self.mask_func is not None:
        undersampled, _ = transforms.apply_mask(
            kspace, self.mask_func, tuple(map(ord, fname)))

    # Zero-filled solution, cropped to a square and reduced to magnitude.
    side = self.resolution
    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(transforms.ifft2(undersampled), (side, side)))
    if self.which_challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)

    magnitude, mean, std = transforms.normalize_instance(magnitude)
    return magnitude.clamp(-6, 6), mean, std, fname, slice
def test_complex_abs(shape):
    """``transforms.complex_abs`` must agree with ``numpy.abs`` on the same data."""
    # Append a trailing axis of size 2 (presumably real/imag parts).
    complex_shape = shape + [2]
    tensor = create_input(complex_shape)
    from_torch = transforms.complex_abs(tensor).numpy()
    from_numpy = np.abs(utils.tensor_to_complex_np(tensor))
    assert np.allclose(from_torch, from_numpy)
def __call__(self, kspace, target, challenge, fname, slice_index):
    """Produce the k-space pair, mask and normalized target for one slice.

    The zero-filled magnitude image is built only for its statistics: the
    target is normalized with that image's mean/std and clamped to [-6, 6];
    the image itself is not returned.

    Args:
        kspace: Raw k-space, converted with ``transforms.to_tensor``.
        target: Target image, converted with ``transforms.to_tensor``.
        challenge (str): ``'multicoil'`` enables a root-sum-of-squares.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice_index (int): Passed through unchanged.

    Returns:
        tuple: (original_kspace, masked_kspace, mask, target, fname,
        slice_index). Both k-spaces go through ``cartesianToPolar`` when
        ``self.polar`` is set.
    """
    full_kspace = transforms.to_tensor(kspace)
    if self.reduce:
        full_kspace = reducedimension(full_kspace, self.resolution)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    undersampled, mask = transforms.apply_mask(full_kspace, self.mask_func, seed)

    zero_filled = transforms.complex_center_crop(
        transforms.ifft2(undersampled), (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

    target = transforms.normalize(
        transforms.to_tensor(target), mean, std, eps=1e-11)
    target = target.clamp(-6, 6)

    if self.polar:
        full_kspace = cartesianToPolar(full_kspace)
        undersampled = cartesianToPolar(undersampled)
    return full_kspace, undersampled, mask, target, fname, slice_index
def train_step(model, data, device):
    """Compute the training loss for one batch.

    The network output is optionally un-normalized (see the ``TRAIN_COMPLEX``
    / ``RENORM`` branches), converted to magnitude, optionally renormalized
    with the per-sample magnitude statistics, and compared with ``target``
    using ``F.smooth_l1_loss`` (when ``SMOOTH``) or ``F.l1_loss``.

    Args:
        model: Network applied to the channels-first permutation of the input.
        data: 8-tuple ``(input, target, mean, std, norm, _, mean_abs, std_abs)``;
            ``norm`` is unpacked but not used.
        device: Device the tensors are moved to.

    Returns:
        The loss. When ``RENORM`` is off it is multiplied by ``1e9``
        (presumably to offset the small scale of un-normalized magnitudes —
        confirm).

    NOTE(review): the source was collapsed onto one line. The nesting below
    (``complex_abs`` and the ``RENORM`` block at function level) is a
    reconstruction — confirm against version control. With both
    ``TRAIN_COMPLEX`` and ``RENORM`` set, the output is never un-normalized
    before being renormalized with ``mean_abs``/``std_abs``; confirm this is
    intended.
    """
    input, target, mean, std, norm, _, mean_abs, std_abs = data
    input = input.to(device)
    target = target.to(device)
    # The model works channels-first; data is kept channels-last.
    output = model(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    mean = mean.to(device)
    std = std.to(device)
    if TRAIN_COMPLEX and not RENORM:
        output = transforms.unnormalize(output, mean, std)
    elif not TRAIN_COMPLEX:
        output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_abs(output)
    if RENORM:
        # Two trailing singleton axes let the per-sample statistics broadcast.
        mean_abs = mean_abs.unsqueeze(1).unsqueeze(2).to(device)
        std_abs = std_abs.unsqueeze(1).unsqueeze(2).to(device)
        output = transforms.normalize(output, mean_abs, std_abs)
    loss_f = F.smooth_l1_loss if SMOOTH else F.l1_loss
    loss = loss_f(output, target)
    if RENORM:
        return loss
    else:
        return 1e9 * loss
def resize(hparams, image, target):
    """Crop a complex image (and optional target) and normalize its magnitude.

    The crop is at most ``hparams.resolution`` in each image dimension and is
    further limited by the target's size when a target is given.

    Args:
        hparams: Namespace providing ``resolution`` and ``challenge``.
        image: Complex image; dims -3/-2 are treated as height/width.
        target: Target array (converted with ``transforms.to_tensor``) or None.

    Returns:
        tuple: (cropped complex image, normalized and clamped magnitude image,
        target normalized with the magnitude's statistics and clamped — or
        ``torch.Tensor([0])`` when ``target`` is None, mean, std).
    """
    crop_height = min(hparams.resolution, image.shape[-3])
    crop_width = min(hparams.resolution, image.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    image = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, magnitude, target, mean, std
def forward(self, masked_kspace, mask):
    """Refine k-space through the cascades and return the RSS magnitude image.

    Args:
        masked_kspace: Undersampled k-space; also handed to every cascade.
        mask: Sampling mask, passed to the sensitivity network and cascades.

    Returns:
        Root-sum-of-squares over dim 1 of the magnitude of ``T.ifft2`` applied
        to the final k-space estimate.
    """
    sens_maps = self.sens_net(masked_kspace, mask)
    # Start the iteration from a copy of the measured k-space.
    estimate = masked_kspace.clone()
    for cascade in self.cascades:
        estimate = cascade(estimate, masked_kspace, mask, sens_maps)
    coil_magnitudes = T.complex_abs(T.ifft2(estimate))
    return T.root_sum_of_squares(coil_magnitudes, dim=1)
def __call__(self, kspace, target, attrs, fname, slice):
    """Return the k-space of the normalized, clamped magnitude image.

    No mask is applied. The image is centre-cropped to ``self.resolution``,
    combined with a root-sum-of-squares for multicoil data (applied before
    ``complex_abs`` here), normalized with ``transforms.normalize_instance``,
    clamped to [-6, 6] and transformed back with ``transforms.rfft2``.

    Args:
        kspace (numpy.Array): k-space measurements.
        target (numpy.Array): Target image (not used here).
        attrs (dict): Acquisition related information (not used here).
        fname: File name, passed through.
        slice (int): Serial number of the slice, passed through.

    Returns:
        (tuple): (kspace, mean, std, fname, slice) where ``kspace`` is the
        ``rfft2`` of the processed image and ``mean``/``std`` are the
        normalization statistics.
    """
    spatial = transforms.ifft2(transforms.to_tensor(kspace))
    spatial = transforms.complex_center_crop(spatial, (self.resolution,) * 2)
    if self.which_challenge == 'multicoil':
        spatial = transforms.root_sum_of_squares(spatial)
    magnitude, mean, std = transforms.normalize_instance(
        transforms.complex_abs(spatial), eps=1e-11)
    return transforms.rfft2(magnitude.clamp(-6, 6)), mean, std, fname, slice
def __call__(self, ksp, sens, mask, fname, slice):
    """Prepare one undersampled slice: coil-combined magnitude/phase and k-space.

    Args:
        ksp (numpy.ndarray): Undersampled k-space whose last axis interleaves
            real (even indices) and imaginary (odd indices) parts.
        sens (numpy.ndarray): Coil sensitivity maps (converted with ``T.to_tensor``).
        mask (numpy.ndarray): Sampling mask; duplicated along a new trailing
            axis of size 2 and cast to float.
        fname: Path-like object; only ``fname.name`` is used.
        slice (int): Slice index, passed through.

    Returns:
        tuple: (magnitude / scale, padded phase, k-space / scale, sensitivity
        tensor, mask, fname.name, slice, max of the zero-filled numpy
        reconstruction), where ``scale`` is the maximum of the padded
        magnitude image.
    """
    mask = torch.from_numpy(mask)
    mask = torch.stack((mask, mask), dim=-1).float()

    # Re-assemble complex samples from the interleaved real/imag layout.
    ksp_cmplx = ksp[:, :, ::2] + 1j * ksp[:, :, 1::2]
    sens_t = T.to_tensor(sens)
    ksp_t = T.to_tensor(ksp_cmplx)
    # Move axis 2 to the front (presumably the coil axis — confirm).
    ksp_us = ksp_t.permute(2, 0, 1, 3)
    img_us = T.ifft2(ksp_us)
    img_us_sens = T.combine_all_coils(img_us, sens_t)
    pha_us = T.phase(img_us_sens)
    mag_us = T.complex_abs(img_us_sens)
    mag_us_pad = T.pad(mag_us, [256, 256])
    pha_us_pad = T.pad(pha_us, [256, 256])

    # The numpy zero-filled reconstruction needs exactly the complex array
    # built above; reuse it instead of rebuilding it from ``ksp``.
    img_us_np = T.zero_filled_reconstruction(ksp_cmplx)

    # One normalization scale shared by the magnitude image and the k-space.
    scale = mag_us_pad.max()
    return (mag_us_pad / scale, pha_us_pad, ksp_us / scale, sens_t, mask,
            fname.name, slice, img_us_np.max())
def __call__(self, kspace, mask, target, attrs, fname, slice):
    """Prepare a slice (optionally undersampled here) for evaluation.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset; not read by this method.
        target (numpy.array or None): Target image.
        attrs (dict): Acquisition related information (not used here).
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): (image, target, mean, std, fname, slice). ``image`` is the
        zero-filled magnitude, cropped to at most ``self.resolution`` (and the
        target's size), normalized and clamped to [-6, 6]. ``target`` is
        normalized with the same mean/std and clamped, or
        ``torch.Tensor([0])`` when no target is given.
    """
    kspace = transforms.to_tensor(kspace)
    undersampled = kspace
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        undersampled, mask = transforms.apply_mask(kspace, self.mask_func, seed)

    image = transforms.ifft2(undersampled)

    # Crop to the requested resolution, never beyond the image or target size.
    crop_height = min(self.resolution, image.shape[-3])
    crop_width = min(self.resolution, image.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    image = transforms.complex_abs(transforms.complex_center_crop(image, crop_size))
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, target, mean, std, fname, slice
def __call__(self, kspace, target, attrs, fname, slice): kspace_rect = transforms.to_tensor(kspace) ##rectangular kspace image_rect = transforms.ifft2(kspace_rect) ##rectangular FS image image_square = transforms.complex_center_crop( image_rect, (self.resolution, self.resolution)) ##cropped to FS square image kspace_square = self.c3object.apply( transforms.fft2(image_square)) #* 10000 ##kspace of square iamge if self.augmentation: kspace_square = self.augmentation.apply(kspace_square) image_square = ifft_c3(kspace_square) # Apply mask seed = None if not self.use_seed else tuple(map(ord, fname)) masked_kspace_square, mask = transforms.apply_mask( kspace_square, self.mask_func, seed) ##ZF square kspace # Inverse Fourier Transform to get zero filled solution # image = transforms.ifft2(masked_kspace) image_square_us = ifft_c3( masked_kspace_square) ## US square complex image # Crop input image # image = transforms.complex_center_crop(image, (self.resolution, self.resolution)) # Absolute value # image = transforms.complex_abs(image) image_square_abs = transforms.complex_abs( image_square_us) ## US square real image # Apply Root-Sum-of-Squares if multicoil data # if self.which_challenge == 'multicoil': # image = transforms.root_sum_of_squares(image) # Normalize input # image, mean, std = transforms.normalize_instance(image, eps=1e-11) _, mean, std = transforms.normalize_instance(image_square_abs, eps=1e-11) # image = image.clamp(-6, 6) # target = transforms.to_tensor(target) target = image_square.permute(2, 0, 1) # Normalize target # target = transforms.normalize(target, mean, std, eps=1e-11) # target = target.clamp(-6, 6) # return image, target, mean, std, attrs['norm'].astype(np.float32) # return masked_kspace_square.permute((2,0,1)), image, image_square.permute(2,0,1), mean, std, attrs['norm'].astype(np.float32) # ksp, zf, target, me, st, nor return masked_kspace_square.permute((2,0,1)), image_square_us.permute((2,0,1)), \ target, \ mean, std, 
attrs['norm'].astype(np.float32)
def __call__(self, kspace, target, attrs, fname, slice):
    """Build one training sample with a magnitude-normalized input and target.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5
            object; ``attrs['norm']`` is read.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice (not used here).

    Returns:
        (tuple): (image, target_train, mean, std, norm, target). ``image`` is
        the zero-filled magnitude normalized with its own statistics;
        ``target_train`` is the target normalized with those statistics.
        Both are clamped to [-6, 6] when ``CLAMP`` is set. ``norm`` is
        ``attrs['norm']`` as float32, and ``target`` is the raw target tensor
        (kept for visualization).
    """
    kspace = transforms.to_tensor(kspace)
    seed = tuple(map(ord, fname)) if self.use_seed else None
    if self.use_mask:
        undersampled = transforms.get_mask(kspace, self.mask_func, seed) * kspace
    else:
        undersampled = kspace

    image = transforms.ifft2(undersampled)
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    image = transforms.complex_abs(image)
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)

    target = transforms.to_tensor(target)
    target_train = transforms.normalize(target, mean, std, eps=1e-11)
    if CLAMP:
        image = image.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)

    norm = attrs['norm'].astype(np.float32)
    return image, target_train, mean, std, norm, target
def inference(model, data, device):
    """Run the model on one batch and return its magnitude prediction.

    Args:
        model: Network applied to the channels-first permutation of the input.
        data: 8-tuple batch; items 0 (network input), 2 (mean), 3 (std) and
            5 (target) are used.
        device: Device the tensors are moved to.

    Returns:
        tuple: (output, target). ``output`` is the network prediction
        un-normalized with ``mean``/``std`` and converted with
        ``transforms.complex_abs``; ``target`` is moved to ``device``.
    """
    # Inference only: no autograd graph is needed (matches the other inference()).
    with torch.no_grad():
        # mean_abs / std_abs (items 6 and 7) were previously prepared here but
        # never used, so they are ignored.
        network_input, _, mean, std, _, target, _, _ = data
        network_input = network_input.to(device)
        target = target.to(device)
        output = model(network_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        output = transforms.unnormalize(output, mean.to(device), std.to(device))
        output = transforms.complex_abs(output)
    return output, target
def __call__(self, kspace, target, attrs, fname, slice):
    """Prepare masked k-space (resized along dim 1) and a normalized image.

    Args:
        kspace (numpy.Array): k-space measurements.
        target (numpy.Array): Target image (not used here).
        attrs (dict): Acquisition related information (not used here).
        fname: File name; when ``self.mask_func`` is set, the mask seed is
            ``tuple(map(ord, fname))``, so it must iterate to characters.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): (masked_kspace, image, mean, std, fname, slice).
        ``masked_kspace`` has dim 1 resized to exactly ``self.kspace_x`` by
        centre cropping or centred zero padding. ``image`` is the cropped
        zero-filled magnitude (RSS for multicoil), normalized and clamped to
        [-6, 6]; ``mean``/``std`` are its normalization statistics.
    """
    kspace = transforms.to_tensor(kspace)
    if self.mask_func is not None:
        seed = tuple(map(ord, fname))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace

    # Inverse Fourier Transform to get zero filled solution
    image = transforms.ifft2(masked_kspace)
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    image = transforms.complex_abs(image)
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    image, mean, std = transforms.normalize_instance(image)
    image = image.clamp(-6, 6)

    # Difference between the actual and the requested size of dim 1.
    extra = int(masked_kspace.shape[1] - self.kspace_x)
    target_width = masked_kspace.shape[1] - extra
    if extra > 0:
        # Centre crop. Slicing by start + width (instead of trimming
        # extra // 2 from both ends) is correct for odd ``extra`` too; the old
        # form kept one row too many and, for extra == 1, kept nothing.
        start = extra // 2
        masked_kspace = masked_kspace[:, start:start + target_width]
    elif extra < 0:
        # Centred zero padding with the input's dtype/device and rank; the
        # leftover row for an odd difference goes to the end.
        padded_shape = list(masked_kspace.shape)
        padded_shape[1] = target_width
        padded = masked_kspace.new_zeros(padded_shape)
        start = (-extra) // 2
        padded[:, start:start + masked_kspace.shape[1]] = masked_kspace
        masked_kspace = padded
    # TODO: return mask as well for exclusive updates
    return masked_kspace, image, mean, std, fname, slice
def inference(model, data, device):
    """Predict k-space for a batch and return prediction/target magnitudes.

    The prediction is made data-consistent by keeping the input where
    ``mask`` is 1. Prediction and target are un-normalized, inverse
    transformed, centre-cropped to 320x320 and converted to magnitude.
    Everything runs under ``torch.no_grad()``.

    Args:
        model: Network applied to the channels-first permutation of the input.
        data: 7-tuple ``(input, target, mean, std, _, _, mask)``.
        device: Device the tensors are moved to.

    Returns:
        tuple: (output magnitude, target magnitude).
    """
    with torch.no_grad():
        measured, target, mean, std, _, _, mask = data
        measured = measured.to(device)
        mask = mask.to(device)
        predicted = model(measured.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        predicted = measured * mask + (1 - mask) * predicted
        target = target.to(device)
        mean, std = mean.to(device), std.to(device)

        def to_magnitude(kspace):
            spatial = transforms.ifft2(transforms.unnormalize(kspace, mean, std))
            spatial = transforms.complex_center_crop(spatial, (320, 320))
            return transforms.complex_abs(spatial)

        output = to_magnitude(predicted)
        target = to_magnitude(target)
    return output, target
def to_spatial(self, kspace, resolution):
    """Convert (post-enhancement) k-space into a normalized magnitude image.

    Args:
        kspace (torch.Tensor): k-space after enhancement.
        resolution (int): Side length of the square centre crop.

    Returns:
        torch.Tensor: Cropped magnitude image, normalized with
        ``transforms.normalize_instance`` and clamped to [-6, 6].
    """
    spatial = transforms.ifft2(kspace)
    spatial = transforms.complex_center_crop(spatial, (resolution,) * 2)
    magnitude, _, _ = transforms.normalize_instance(
        transforms.complex_abs(spatial), eps=1e-11)
    return magnitude.clamp(-6, 6)
def kspacetoimage(kspace, args):
    """Convert k-space to a normalized, clamped zero-filled magnitude image.

    Args:
        kspace: k-space tensor accepted by ``transforms.ifft2``.
        args: Namespace providing ``resolution`` (square crop size) and
            ``challenge`` (``'multicoil'`` triggers a root-sum-of-squares).

    Returns:
        The magnitude image normalized with ``transforms.normalize_instance``
        and clamped to [-6, 6].
    """
    side = args.resolution
    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(transforms.ifft2(kspace), (side, side)))
    if args.challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
def generate(generator, data, device):
    """Run the generator and return its cropped magnitude output.

    Args:
        generator: Network applied to the channels-first permutation of the input.
        data: 8-tuple batch; only items 0 (network input), 2 (mean) and
            3 (std) are used.
        device: Device the tensors are moved to.

    Returns:
        The prediction un-normalized with ``mean``/``std``, centre-cropped to
        320x320 and converted with ``transforms.complex_abs``.
    """
    # Item 4 (the mask) is not needed without projection, so it is not copied
    # to the device.
    network_input, _, mean, std, _, _, _, _ = data
    network_input = network_input.to(device)
    output_network = generator(network_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # Take loss on the cropped, real valued image (abs)
    mean = mean.to(device)
    std = std.to(device)
    output_network = transforms.unnormalize(output_network, mean, std)
    output_network = transforms.complex_center_crop(output_network, (320, 320))
    return transforms.complex_abs(output_network)
def data_for_training(rawdata, sensitivity, mask_func, norm=True):
    """Build undersampled and fully sampled coil-combined images for one slice.

    When ``norm`` is truthy, the data is divided by the maximum complex
    magnitude of the zero-filled reconstruction (floored at 1e-6). The
    zero-filled recon is used because no ground truth exists at inference.

    Args:
        rawdata (numpy.ndarray): k-space, transposed with ``(2, 0, 1)`` and cast
            to complex64 (presumably stored as rows x cols x coils — confirm).
        sensitivity (numpy.ndarray): Coil sensitivities, transposed the same way.
        mask_func: Called with the k-space shape (leading dims set to 1);
            the returned mask is ``T.ifftshift``-ed.
        norm (bool): Whether to apply the normalization above.

    Returns:
        tuple: (sense_und, sense_gt, rawdata_und, masks, sensitivity), where
        ``masks`` is the mask repeated to ``(coils, Ny, 1, ps)`` multiples.
    """
    rawdata = T.to_tensor(np.complex64(rawdata.transpose(2, 0, 1)))
    sensitivity = T.to_tensor(sensitivity.transpose(2, 0, 1))
    coils, Ny, Nx, ps = rawdata.shape

    # Shift data: ifftshift plus a (-1)^(x+y) checkerboard modulation
    # (presumably matching the FFT convention of T.ifft2 — confirm).
    x, y = np.meshgrid(np.arange(1, Nx + 1), np.arange(1, Ny + 1))
    shifted = T.ifftshift(rawdata, dim=(-3, -2))
    checkerboard = torch.from_numpy((-1) ** (x + y)).view(1, Ny, Nx, 1).float()
    shift_kspace = shifted * checkerboard

    mask_shape = np.array(shift_kspace.shape)
    mask_shape[:-3] = 1
    mask = T.ifftshift(mask_func(mask_shape))
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), shift_kspace)
    masks = mask.repeat(coils, Ny, 1, ps)

    img_gt = T.ifft2(shift_kspace)
    img_und = T.ifft2(masked_kspace)

    if not norm:
        scale = 1
    else:
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    sense_und = cobmine_all_coils(img_und, sensitivity)
    return sense_und, sense_gt, rawdata_und, masks, sensitivity
def k_space_to_image_with_mask(kspace, mask_func=None, seed=None):
    """Convert k-space (optionally masked first) to a normalized magnitude image.

    Args:
        kspace: k-space tensor.
        mask_func: If truthy, passed with ``seed`` to ``transforms.apply_mask``
            before the inverse transform.
        seed: Seed forwarded to ``transforms.apply_mask``.

    Returns:
        The magnitude image, centre-cropped to 320x320, normalized with
        ``transforms.normalize_instance`` and clamped to [-6, 6].
    """
    if mask_func:
        kspace, _ = transforms.apply_mask(kspace, mask_func, seed)
    magnitude = transforms.complex_abs(transforms.ifft2(kspace))
    magnitude = transforms.center_crop(magnitude, (320, 320))
    normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """Convert normalized k-space to a re-normalized magnitude image.

    Args:
        args: Namespace providing ``resolution`` (square crop size).
        kspace_fni: Normalized k-space whose last dimension must be 2.
        mean, std: Used to un-normalize the image after the inverse FFT
            (``image * std + mean``).
        eps (float): Forwarded to ``transforms.normalize_instance``. It was
            previously ignored in favour of a hard-coded ``1e-11``, the default.

    Returns:
        The first entry along dim 0 (presumably the batch — confirm),
        centre-cropped, converted to magnitude, normalized and clamped to
        [-6, 6].
    """
    assert kspace_fni.size(-1) == 2
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    # NOTE: ``torch.ifft`` was removed in PyTorch 1.8; this needs an older
    # PyTorch (as the rest of this file appears to).
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # Undo the normalization of the image.
    image = (image * std) + mean
    image = image[0]
    image = transforms.complex_center_crop(image, (args.resolution, args.resolution))
    image = transforms.complex_abs(image)
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    return image.clamp(-6, 6)
def eval(args, model, data_loader):
    """Reconstruct every slice from ``data_loader`` and group results by file.

    Args:
        args: Namespace providing ``device``.
        model: Network; put in eval mode and run under ``torch.no_grad()``.
        data_loader: Yields ``(input, mean, std, fnames, slices)`` batches.

    Returns:
        dict: file name -> numpy array of that file's reconstructions, stacked
        after sorting by slice number. Each reconstruction is the magnitude of
        the model output, mapped back with ``recon * std + mean``.
    """
    model.eval()
    grouped = defaultdict(list)
    with torch.no_grad():
        for batch, mean, std, fnames, slices in data_loader:
            recons = model(batch.to(args.device)).to('cpu').squeeze(1)
            recons = transforms.complex_abs(recons)  # complex to real
            for i in range(recons.shape[0]):
                recons[i] = recons[i] * std[i] + mean[i]
                entry = (slices[i].numpy(), recons[i].numpy())
                grouped[fnames[i]].append(entry)

    def stack_by_slice(slice_preds):
        ordered = sorted(slice_preds)
        return np.stack([pred for _, pred in ordered])

    return {fname: stack_by_slice(preds) for fname, preds in grouped.items()}
def kspaceto2dimage(kspace, polar, cropping=False, resolution=None):
    """Convert (optionally polar) k-space into a normalized magnitude image.

    Args:
        kspace: k-space tensor; converted with ``polarToCartesian`` first when
            ``polar`` is truthy.
        polar (bool): Whether ``kspace`` is in polar form.
        cropping (bool): If True, the image is produced by
            ``croppedimage(kspace, resolution)`` instead of ``transforms.ifft2``.
        resolution: Required (truthy) when ``cropping`` is True.

    Returns:
        The magnitude image normalized with ``transforms.normalize_instance``
        and clamped to [-6, 6].

    Raises:
        ValueError: If ``cropping`` is True and ``resolution`` is falsy.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    if polar:
        kspace = polarToCartesian(kspace)
    if cropping:
        if not resolution:
            raise ValueError(
                "If cropping = True, pass the value for resolution for the function: kspaceto2dimage"
            )
        image = croppedimage(kspace, resolution)
    else:
        image = transforms.ifft2(kspace)
    image = transforms.complex_abs(image)
    image, _, _ = transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6)
def data_for_training(rawdata, sensitivity, mask, norm=True):
    """Build undersampled/fully sampled coil-combined data from a given mask.

    When ``norm`` is truthy, the data is divided by the maximum complex
    magnitude of the zero-filled reconstruction (floored at 1e-6). The
    zero-filled recon is used because no ground truth exists at inference.

    Args:
        rawdata (torch.Tensor): k-space unpacked as ``(coils, Ny, Nx, ps)``.
        sensitivity: Coil sensitivities passed to ``cobmine_all_coils``.
        mask (torch.Tensor): Sampling mask; ``T.ifftshift``-ed, given a leading
            and a trailing singleton axis, cast to float and repeated to
            ``(coils, 1, 1, ps)`` multiples.
        norm (bool): Whether to apply the normalization above.

    Returns:
        tuple: (sense_und, sense_gt, sense_und_kspace, rawdata_und, mask,
        sensitivity), where ``sense_und_kspace`` is ``T.fft2(sense_und)``.
    """
    coils, Ny, Nx, ps = rawdata.shape

    # Shift data: ifftshift plus a (-1)^(x+y) checkerboard modulation
    # (presumably matching the FFT convention of T.ifft2 — confirm).
    x, y = np.meshgrid(np.arange(1, Nx + 1), np.arange(1, Ny + 1))
    shifted = T.ifftshift(rawdata, dim=(-3, -2))
    checkerboard = torch.from_numpy((-1) ** (x + y)).view(1, Ny, Nx, 1).float()
    shift_kspace = shifted * checkerboard

    mask = T.ifftshift(mask)
    mask = mask.unsqueeze(0).unsqueeze(-1).float()
    mask = mask.repeat(coils, 1, 1, ps)
    masked_kspace = shift_kspace * mask

    img_gt = T.ifft2(shift_kspace)
    img_und = T.ifft2(masked_kspace)

    if not norm:
        scale = 1
    else:
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    sense_und = cobmine_all_coils(img_und, sensitivity)
    sense_und_kspace = T.fft2(sense_und)
    return sense_und, sense_gt, sense_und_kspace, rawdata_und, mask, sensitivity
def forward(self, slice, kspace, mask):
    """Run three recurrent ConvRNN stages and the final block on one image.

    Args:
        slice: Input image tensor fed to ``self.convrnn1``; its device is used
            for the hidden states.
        kspace, mask: Passed unchanged to every recurrent step and to
            ``self.final``.

    Returns:
        torch.Tensor: Magnitude of the final block's output with a leading
        dimension added (1 x 1 x 320 x 320 per the original comment).
    """
    # Hidden states must live on the input's device. Choosing CUDA whenever it
    # is available caused a device mismatch for CPU runs on GPU machines.
    device = slice.device
    # Hidden state tensors for the first iteration, initialized to 0
    # (sizes fixed by the architecture; batch size 1).
    h1 = torch.zeros(1, 384, 80, 80, device=device)
    h2 = torch.zeros(1, 192, 160, 160, device=device)
    h3 = torch.zeros(1, 96, 320, 320, device=device)

    out1 = slice
    for _ in range(5):
        out1, h1 = self.convrnn1(input=out1, hidden_input=h1, kspace=kspace, mask=mask)
    out2 = out1
    for _ in range(5):
        out2, h2 = self.convrnn2(input=out2, hidden_input=h2, kspace=kspace, mask=mask)
    out3 = out2
    for _ in range(5):
        out3, h3 = self.convrnn3(input=out3, hidden_input=h3, kspace=kspace, mask=mask)

    # Concatenate the outputs of the three ConvRNN stages along dim 1.
    out = torch.cat((out1, out2, out3), 1)
    out = self.final(out, kspace, mask)  # final block
    out = transforms.complex_abs(out)  # complex image -> real image; 320 x 320
    out = torch.unsqueeze(out, 0)  # add a dimension to get 1 x 1 x 320 x 320
    return out
def generate(generator, data, device):
    """Run the generator and return its (optionally projected) magnitude output.

    Args:
        generator: Network applied to the channels-first permutation of the
            input; its output is treated as a residual prediction.
        data: 8-tuple batch; items 0 (input), 2 (mean), 3 (std) and 4 (mask)
            are used.
        device: Device the tensors are moved to.

    Returns:
        The output un-normalized with ``mean``/``std``, centre-cropped to
        320x320 and converted with ``transforms.complex_abs``.
    """
    network_input, _, mean, std, mask, _, _, _ = data
    network_input = network_input.to(device)
    mask = mask.to(device)
    # Use network to predict residual
    residual = generator(network_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    if PROJECT:
        # Projection to consistent k-space
        output = project_to_consistent_subspace(residual, network_input, mask)
    else:
        # Previously ``output`` was unbound on this path (NameError).
        # NOTE(review): falls back to the raw network prediction; confirm
        # whether ``network_input + residual`` was intended instead.
        output = residual
    # Take loss on the cropped, real valued image (abs)
    mean = mean.to(device)
    std = std.to(device)
    output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_center_crop(output, (320, 320))
    return transforms.complex_abs(output)
def save_zero_filled(data_dir, out_dir, which_challenge, resolution):
    """Save zero-filled magnitude reconstructions for every file in ``data_dir``.

    Each file's ``'kspace'`` dataset is inverse transformed, centre-cropped to
    at most ``resolution`` in both image dimensions and converted to
    magnitude. For ``which_challenge == 'multicoil'`` it is then combined
    with a root-sum-of-squares over dim 1. Results, keyed by file name, are
    written with ``save_reconstructions``.
    """
    reconstructions = {}
    for path in data_dir.iterdir():
        print("file:{}".format(path))
        with h5py.File(path, "r") as hf:
            kspace = transforms.to_tensor(hf['kspace'][()])
        image = transforms.ifft2(kspace)
        height = min(resolution, image.shape[-3])
        width = min(resolution, image.shape[-2])
        image = transforms.complex_abs(
            transforms.complex_center_crop(image, (height, width)))
        if which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image, dim=1)
        reconstructions[path.name] = image
    save_reconstructions(reconstructions, out_dir)
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and log the results.

    Relies on module-level state: ``model``, ``test_loader``, ``data_list``,
    ``device``, ``mse``, ``vis`` and its ``test_*_window`` handles,
    ``val_log``, and the ``ssim``/``psnr``/``nmse`` metrics. Every fifth
    iteration, an image strip (zero-filled | recon | ground truth | error) and
    metric points are sent to visdom. For every sample, a comma-separated row
    of baseline (zero-filled) and reconstruction metrics is written to
    ``val_log``.

    NOTE(review): the source was collapsed onto one line. The nesting here
    (visdom calls inside ``iteration % 5 == 0``, per-sample logging on every
    iteration) is a reconstruction — confirm against version control.

    Args:
        epoch (int): Current epoch, used for logging and the plot x-axis.
    """
    model.eval()  # test mode
    data_len = len(data_list['val'])
    row_fmt = '{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}\n'
    # Validation never back-propagates: skip building and holding the
    # autograd graph for every batch.
    with torch.no_grad():
        for iteration, samples in enumerate(test_loader):
            print(' iteration {} out of {} in validation'.format(iteration, epoch))
            img_und, img_gt, rawdata_und, masks, sensitivity = samples
            # torch.as_tensor avoids the copy + UserWarning that
            # torch.tensor(existing_tensor) produces.
            img_gt = torch.as_tensor(img_gt).to(device)
            img_und = torch.as_tensor(img_und).to(device)
            rawdata_und = torch.as_tensor(rawdata_und).to(device)
            masks = torch.as_tensor(masks).to(device)
            sensitivity = torch.as_tensor(sensitivity).to(device)

            rec = model(img_und, rawdata_und, masks, sensitivity)
            loss = mse(rec, img_gt)

            sense_recon = T.complex_abs(rec).cpu().numpy()
            sense_gt = T.complex_abs(img_gt).cpu().numpy()
            sense_und = T.complex_abs(img_und).cpu().numpy()

            if iteration % 5 == 0:
                A = sense_und[0] / (sense_und.max())
                B = sense_recon[0] / (sense_recon.max())
                C = sense_gt[0] / (sense_gt.max())
                vis.image(np.clip(abs(np.c_[A, B, C, C - B]), 0, 1),
                          win=test_image_window, opts=dict(title='test'))
                x_pos = iteration + epoch * data_len
                vis.line(X=np.array([x_pos]), Y=np.array([loss.item()]),
                         update='append', win=test_loss_window)
                vis.line(X=np.array([x_pos]),
                         Y=np.array([ssim(sense_gt[0], sense_recon[0])]),
                         update='append', win=test_ssim_window)
                vis.line(X=np.array([x_pos]),
                         Y=np.array([psnr(sense_gt[0], sense_recon[0])]),
                         update='append', win=test_psnr_window)
                vis.line(X=np.array([x_pos]),
                         Y=np.array([nmse(sense_gt[0], sense_recon[0])]),
                         update='append', win=test_nmse_window)

            for idx in range(img_gt.shape[0]):
                base_psnr = psnr(abs(sense_gt[idx]), abs(sense_und[idx]))
                base_ssim = ssim(abs(sense_gt[idx]), abs(sense_und[idx]))
                base_nmse = nmse(abs(sense_gt[idx]), abs(sense_und[idx]))
                test_psnr = psnr(abs(sense_gt[idx]), abs(sense_recon[idx]))
                test_ssim = ssim(abs(sense_gt[idx]), abs(sense_recon[idx]))
                test_nmse = nmse(abs(sense_gt[idx]), abs(sense_recon[idx]))
                if idx == 0:
                    # The first row of a batch carries the iteration and loss.
                    val_log.write(row_fmt.format(
                        epoch, iteration, idx, loss.item(), base_psnr,
                        test_psnr, base_ssim, test_ssim, base_nmse, test_nmse))
                else:
                    val_log.write(row_fmt.format(
                        epoch, '', idx, '', base_psnr,
                        test_psnr, base_ssim, test_ssim, base_nmse, test_nmse))
                val_log.flush()