def visualize(args, epoch, model, data_loader, writer):
    """Log target, reconstruction and error images for the first batch.

    Args:
        args: Unused here; kept for interface compatibility.
        epoch: Step value passed to ``writer.add_image``.
        model: Network applied to the ifftshift-ed stacked k-space.
        data_loader: Iterable of 5-tuples; only items 0 (stacked k-space) and
            2 (stacked target image) are used. Only the first batch is read.
        writer: Object exposing ``add_image(tag, image, step)``.
    """

    def save_image(image, tag):
        # Min-max scale into [0, 1]. Done out-of-place so the caller's tensors
        # keep their original scale (the error image is built from them below).
        image = image - image.min()
        peak = image.max()
        if peak > 0:  # a constant image would otherwise become 0/0 = NaN
            image = image / peak
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        batch = next(iter(data_loader), None)
        if batch is None:  # empty loader: nothing to log
            return
        stacked_kspace_square, _, stacked_image_square, _, _ = batch
        ksp_shifted_mc = transforms.ifftshift(stacked_kspace_square, dim=(-2, -1))
        output = model(ksp_shifted_mc.cuda())
        output_rss = stack_to_rss(output)
        target = stack_to_rss(stacked_image_square)
        save_image(target, 'Target')
        save_image(output_rss, 'Reconstruction')
        # Error of the un-normalized images. The model output is moved to CPU
        # because the target batch is never sent to .cuda().
        save_image(target - output_rss.cpu(), 'Error')
def perform(self, out_img_cmplx, ksp, sens, mask):
    """Apply a k-space data-consistency step and return coil-combined terms.

    The complex image is expanded to coil images with ``sens`` and transformed
    to k-space with an orthonormal FFT, then ifftshift-ed. Samples along dim -2
    beyond ``ceil(0.85 * N)`` are zeroed. Inside ``mask`` the measured ``ksp``
    is blended in, weighted by ``self.noise_lvl`` when set, or used as-is when
    ``noise_lvl`` is None. The result is inverse-transformed and combined with
    the conjugate sensitivities.

    Returns:
        (Sx, Ss): sum over dim 1 of conj(sens) * x, and sum over dim 1 of
        conj(sens) * sens, both as real/imag-last tensors.
    """
    img_re, img_im = out_img_cmplx[..., 0], out_img_cmplx[..., 1]
    sens_re, sens_im = sens[..., 0], sens[..., 1]

    # Coil images -> orthonormal 2-D FFT -> shift over the two spatial dims.
    coil_imgs = T.complex_multiply(img_re.unsqueeze(1), img_im.unsqueeze(1), sens_re, sens_im)
    kspace = T.ifftshift(torch.fft(coil_imgs, 2, normalized=True).squeeze(1), dim=(-3, -2))

    # Keep the first ceil(0.85 * N) entries along dim -2 and zero the remainder.
    n_keep = int(np.ceil(kspace.shape[-2] * 0.85))
    kspace[:, :, :, n_keep:, :] = 0

    # Inside the mask, use the measured data, optionally mixed with the prediction.
    lam = self.noise_lvl
    measured = ksp if lam is None else lam * kspace + (1 - lam) * ksp
    blended = (1 - mask) * kspace + mask * measured

    coil_imgs = torch.ifft(blended, 2, normalized=True)
    Sx = T.complex_multiply(coil_imgs[..., 0], coil_imgs[..., 1], sens_re, -sens_im).sum(dim=1)
    Ss = T.complex_multiply(sens_re, sens_im, sens_re, -sens_im).sum(dim=1)
    return Sx, Ss
def evaluate(args, epoch, model, data_loader, writer):
    """Evaluate ``model`` on ``data_loader`` with a summed MSE loss.

    Each batch is a 5-tuple. The stacked k-space (item 0) is ifftshift-ed over
    its last two dims and fed to the model, and the output is compared with the
    stacked image (item 2). No gradients are tracked.

    Returns:
        (mean_loss, seconds): mean of the per-batch losses and the elapsed
        wall-clock time. ``mean_loss`` is also logged as 'Dev_Loss' when a
        writer is given.
    """
    model.eval()
    batch_losses = []
    t0 = time.perf_counter()
    with torch.no_grad():
        for batch in tqdm(data_loader):
            kspace_stack, _, image_stack, _, _ = batch
            kspace_shifted = transforms.ifftshift(kspace_stack, dim=(-2, -1))
            prediction = model(kspace_shifted.cuda())
            batch_loss = F.mse_loss(prediction, image_stack.cuda(), reduction='sum')
            batch_losses.append(batch_loss.item())
        mean_loss = np.mean(batch_losses)
        if writer is not None:
            writer.add_scalar('Dev_Loss', mean_loss, epoch)
    return mean_loss, time.perf_counter() - t0
def data_for_training(rawdata, sensitivity, mask_func, norm=True):
    """Build under-sampled / fully-sampled training pairs from raw k-space.

    Every output is divided by a single scalar. When ``norm`` is truthy, that
    scalar is the peak complex magnitude of the zero-filled reconstruction,
    floored at 1e-6; the zero-filled image is used because no ground truth
    exists at inference. Otherwise the scalar is 1.

    Returns:
        (sense_und, sense_gt, rawdata_und, masks, sensitivity), where
        ``sensitivity`` is the tensor-converted input.
    """
    # Move the last numpy axis first, then convert to tensors. The layout is
    # assumed to be (Ny, Nx, coils) in -> (coils, Ny, Nx, 2) out — TODO confirm.
    kspace = T.to_tensor(np.complex64(rawdata.transpose(2, 0, 1)))
    sens_maps = T.to_tensor(sensitivity.transpose(2, 0, 1))
    n_coils, n_y, n_x, n_parts = kspace.shape

    # Shift over the spatial dims and modulate by a (-1)**(x + y) checkerboard.
    cols, rows = np.meshgrid(np.arange(1, n_x + 1), np.arange(1, n_y + 1))
    checker = torch.from_numpy((-1) ** (cols + rows)).view(1, n_y, n_x, 1).float()
    full_kspace = T.ifftshift(kspace, dim=(-3, -2)) * checker

    # Ask mask_func for a mask with every dim before the last three set to 1.
    mask_shape = np.array(full_kspace.shape)
    mask_shape[:-3] = 1
    mask = T.ifftshift(mask_func(mask_shape))
    under_kspace = torch.where(mask == 0, torch.Tensor([0]), full_kspace)
    masks = mask.repeat(n_coils, n_y, 1, n_parts)

    img_full = T.ifft2(full_kspace)
    img_under = T.ifft2(under_kspace)

    # max() keeps the tensor unless it is below the 1e-6 floor.
    scale = max(T.complex_abs(img_under).max(), 1e-6) if norm else 1

    img_full = img_full / scale
    img_under = img_under / scale
    rawdata_und = under_kspace / scale

    sense_gt = cobmine_all_coils(img_full, sens_maps)
    sense_und = cobmine_all_coils(img_under, sens_maps)
    return sense_und, sense_gt, rawdata_und, masks, sens_maps
def data_for_training(rawdata, sensitivity, mask, norm=True):
    """Build under-sampled / fully-sampled training data from k-space and a mask.

    ``rawdata`` is unpacked as (coils, Ny, Nx, parts). Every output is divided
    by a single scalar. When ``norm`` is truthy, that scalar is the peak complex
    magnitude of the zero-filled reconstruction, floored at 1e-6; otherwise it
    is 1.

    Returns:
        (sense_und, sense_gt, sense_und_kspace, rawdata_und, mask, sensitivity),
        where ``mask`` is the shifted mask broadcast to (coils, Ny, Nx, parts).
    """
    n_coils, n_y, n_x, n_parts = rawdata.shape

    # Shift over the spatial dims and modulate by a (-1)**(x + y) checkerboard.
    cols, rows = np.meshgrid(np.arange(1, n_x + 1), np.arange(1, n_y + 1))
    checker = torch.from_numpy((-1) ** (cols + rows)).view(1, n_y, n_x, 1).float()
    full_kspace = T.ifftshift(rawdata, dim=(-3, -2)) * checker

    # Shift the mask, then tile it over coils and the trailing dim. The mask is
    # presumably 2-D (Ny, Nx) — TODO confirm.
    mask = T.ifftshift(mask).unsqueeze(0).unsqueeze(-1).float().repeat(n_coils, 1, 1, n_parts)
    under_kspace = full_kspace * mask

    img_full = T.ifft2(full_kspace)
    img_under = T.ifft2(under_kspace)

    # max() keeps the tensor unless it is below the 1e-6 floor.
    scale = max(T.complex_abs(img_under).max(), 1e-6) if norm else 1

    img_full = img_full / scale
    img_under = img_under / scale
    rawdata_und = under_kspace / scale

    sense_gt = cobmine_all_coils(img_full, sensitivity)
    sense_und = cobmine_all_coils(img_under, sensitivity)
    return sense_und, sense_gt, T.fft2(sense_und), rawdata_und, mask, sensitivity
def imagenormalize(data, divisor=None):
    """Return k-space generated from the normalized image of ``data``.

    ``data`` is inverse-transformed to an image and passed through
    ``normalize(image, divisor)``, which is assumed to return
    ``(normalized_image, divisor)``. The normalized image is then taken back to
    k-space with a centred, un-normalized ``torch.fft``.

    Returns:
        (kspace, divisor)
    """
    image = transforms.ifft2(data)
    nimage, divisor = normalize(image, divisor)
    # Transform the *normalized* image. Using ``image`` here would silently
    # discard the normalization.
    kspace = transforms.ifftshift(nimage, dim=(-3, -2))
    kspace = torch.fft(kspace, 2)
    kspace = transforms.fftshift(kspace, dim=(-3, -2))
    return kspace, divisor
def mnormalize(masked_kspace, eps=1e-11):
    """Instance-normalize masked k-space directly in the k-space domain.

    Args:
        masked_kspace: Under-sampled k-space tensor.
        eps: Forwarded to ``transforms.normalize_instance`` (default 1e-11).

    Returns:
        (normalized_kspace, mean, std) as produced by
        ``transforms.normalize_instance(masked_kspace, eps=eps)``.

    NOTE(review): ``nkspacetoimage`` undoes normalization in the image domain
    (``image * std + mean`` after an inverse FFT). That does not invert a
    k-space instance normalization — confirm which domain is intended.
    """
    kspace_norm, mean, std = transforms.normalize_instance(masked_kspace, eps=eps)
    return kspace_norm, mean, std
def onormalize(original_kspace, mean, std, eps=1e-11):
    """Normalize fully-sampled k-space with externally supplied statistics.

    Intended for use with the ``mean``/``std`` returned by ``mnormalize`` so
    that input and target share the same normalization.

    Args:
        original_kspace: Fully-sampled k-space tensor.
        mean, std: Normalization statistics.
        eps: Forwarded to ``transforms.normalize``. It is now honoured instead
            of a hard-coded 1e-11; the default is unchanged.

    Returns:
        ``transforms.normalize(original_kspace, mean, std, eps=eps)``.
    """
    return transforms.normalize(original_kspace, mean, std, eps=eps)
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """Convert normalized k-space to a normalized, cropped magnitude image.

    Steps:
        1. Centred inverse FFT with un-normalized ``torch.ifft``.
        2. Undo the normalization as ``image * std + mean``.
        3. Keep the first element along dim 0.
        4. Centre-crop to ``args.resolution``.
        5. Take the complex magnitude.
        6. Instance-normalize and clamp to [-6, 6].

    Args:
        args: Must provide ``args.resolution`` (square crop size).
        kspace_fni: K-space tensor whose last dim holds (real, imag).
        mean, std: Statistics used to undo the image-domain normalization.
        eps: Forwarded to ``transforms.normalize_instance`` in step 6.

    Returns:
        The processed magnitude image tensor.
    """
    assert kspace_fni.size(-1) == 2
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # Denormalize in the image domain.
    image = (image * std) + mean
    image = image[0]
    image = transforms.complex_center_crop(image, (args.resolution, args.resolution))
    image = transforms.complex_abs(image)
    # The instance statistics are not needed by callers.
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    return image.clamp(-6, 6)
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    """Run one training epoch with a summed L1 loss.

    Each batch is a 5-tuple. The stacked k-space (item 0) is ifftshift-ed over
    its last two dims and fed to the model, and the output is compared with the
    stacked image (item 2).

    Returns:
        (smoothed_loss, seconds): exponential moving average of the batch loss
        (seeded with the first batch, then 0.99 * old + 0.01 * new) and the
        epoch's wall-clock duration.
    """
    model.train()
    ema_loss = 0.
    epoch_start = iter_start = time.perf_counter()
    n_batches = len(data_loader)
    step_offset = epoch * n_batches

    for step, batch in enumerate(tqdm(data_loader)):
        kspace_stack, _, image_stack, _, _ = batch
        kspace_shifted = transforms.ifftshift(kspace_stack, dim=(-2, -1))
        prediction = model(kspace_shifted.cuda())
        loss = F.l1_loss(prediction, image_stack.cuda(), reduction='sum')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the average, writer and log.
        loss_value = loss.item()
        ema_loss = 0.99 * ema_loss + 0.01 * loss_value if step > 0 else loss_value

        if writer is not None:
            writer.add_scalar('TrainLoss', loss_value, step_offset + step)
        if step % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{step:4d}/{n_batches:4d}] '
                f'Loss = {loss_value:.4g} Avg Loss = {ema_loss:.4g} '
                f'Time = {time.perf_counter() - iter_start:.4f}s',
            )
        iter_start = time.perf_counter()

    return ema_loss, time.perf_counter() - epoch_start
self.translation()).float() class C3Convert: def __init__(self, shp=(320, 320)): self.c3m = transforms.c3_torch(shp) # ksp.shape[-2:] def apply(self, ksp): # expect (bat, w, h, 2) return ksp * self.c3m # torch fft requires (bat, w,h, 2) ifft_c3 = lambda kspc3: torch.ifft( transforms.ifftshift(kspc3, dim=(-3, -2)), 2, normalized=True) fft_c3 = lambda im: transforms.fftshift(torch.fft(im, 2, normalized=True), dim=(-3, -2)) class NoTransform: def __init__(self): pass def __call__(self, kspace, target, attrs, fname, slice): return kspace class BasicMaskingTransform: def __init__(self,
def test_ifftshift(shape):
    """transforms.ifftshift must match numpy.fft.ifftshift on an arange array of ``shape``."""
    # np.prod replaces np.product (deprecated in NumPy 1.25, removed in 2.0).
    array = np.arange(np.prod(shape)).reshape(shape)
    out_torch = transforms.ifftshift(torch.from_numpy(array)).numpy()
    out_numpy = np.fft.ifftshift(array)
    assert np.allclose(out_torch, out_numpy)
from data import transforms as T


def c3_multiplier_npy(shape=(320, 320)):
    """Return a (rows, cols) integer checkerboard with entries (-1)**(i + j).

    Built as an outer product, so it is correct for odd and non-square sizes.
    Building it with ``np.resize([1, -1], shape)`` times its transpose gave all
    ones for odd sizes and failed to broadcast for non-square shapes. For even
    square sizes the result is identical.
    """
    rows, cols = shape[0], shape[1]
    row_sign = (-1) ** np.arange(rows)
    col_sign = (-1) ** np.arange(cols)
    return np.outer(row_sign, col_sign)


def c3_torch(shp):
    """Return the checkerboard for ``shp`` as a float tensor of shape (H, W, 2)."""
    c3m = c3_multiplier_npy(shp)
    # Same pattern for the real and imaginary channels.
    return torch.from_numpy(np.dstack((c3m, c3m))).float()


def ifft_c3(kspc3):
    """Orthonormal 2-D inverse FFT after an ifftshift over dims (-3, -2)."""
    return torch.ifft(T.ifftshift(kspc3, dim=(-3, -2)), 2, normalized=True)


def fft_c3(im):
    """Orthonormal 2-D FFT followed by an fftshift over dims (-3, -2)."""
    return T.fftshift(torch.fft(im, 2, normalized=True), dim=(-3, -2))


shp = (360, 360)
c3m = c3_torch(shp)


def tosquare(ksp, shp):
    """Centre-crop ``ksp`` to ``shp`` in image space; return modulated k-space.

    The cropped k-space is multiplied by the checkerboard for ``shp`` and then
    by a fixed factor of 100000. NOTE(review): the purpose of that scale is not
    visible here — confirm.

    The precomputed module-level ``c3m`` is reused when its size matches
    ``shp``. Any other size gets its own checkerboard instead of failing to
    broadcast against the 360x360 one.
    """
    rec = T.ifft2(ksp)
    mult = c3m if tuple(c3m.shape[:2]) == tuple(shp) else c3_torch(shp)
    return mult * T.fft2(T.complex_center_crop(rec, shp)) * 100000
def __call__(self, kspace, target, attrs, fname, slice):
    """Build a (zero-filled input, target) pair from one k-space slice.

    The k-space is first centre-cropped in image space to
    ``self.resolution`` x ``self.resolution`` and transformed back. It is then
    under-sampled with ``self.mask_func``, seeded from ``fname`` when
    ``self.use_seed`` is set.

    The zero-filled image and the fully-sampled target are both scaled by
    ``6.0 / max|zero-filled image|``. Each is returned as stacked real and
    imaginary planes, channels-first.

    Args:
        kspace (numpy.array): Input k-space. The real/imag split via
            ``[:, :, 0]`` / ``[:, :, 1]`` assumes single-coil data — TODO
            confirm behaviour for multicoil input.
        target: Ignored; the target is recomputed from ``kspace``.
        attrs (dict): Ignored.
        fname (str): File name, used to seed the mask.
        slice (int): Ignored.

    Returns:
        (tuple): ``(image, target)``. For single-coil input, each is a tensor
        of shape (2, resolution, resolution).

    NOTE(review): root-sum-of-squares is not applied for multicoil data; an
    all-zero zero-filled image makes the scale factor inf/NaN.
    """
    crop = (self.resolution, self.resolution)
    kspace = transforms.to_tensor(kspace)
    # Crop in image space, then return to k-space so masking acts on the crop.
    kspace = transforms.fft2(transforms.complex_center_crop(transforms.ifft2(kspace), crop))

    seed = tuple(map(ord, fname)) if self.use_seed else None
    masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)

    # Zero-filled reconstruction; its peak magnitude sets the common scale.
    image = transforms.complex_center_crop(transforms.ifft2(masked_kspace), crop)
    image_mod = transforms.complex_abs(image).max()

    def to_planes(x):
        # Scale real/imag by 6 / image_mod and lay them out channels-first.
        real = x[:, :, 0] * 6.0 / image_mod
        imag = x[:, :, 1] * 6.0 / image_mod
        return transforms.to_tensor(np.stack((real, imag), axis=-1).transpose((2, 0, 1)))

    gt = transforms.complex_center_crop(transforms.ifft2(kspace), crop)
    return to_planes(image), to_planes(gt)