def train_epoch(model, optimizer, train_loader):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    For every sample, the undersampled and fully sampled complex images are
    turned into magnitude images, squeezed, centre-cropped to 320x320 and
    reshaped to ``(1, 1, 320, 320)`` on the GPU. The L1 loss between the model
    output and the target is back-propagated and ``optimizer`` takes one step.

    Returns:
        (smoothed_loss, elapsed_seconds): an exponential moving average of the
        per-iteration loss (weight 0.99 on the history, seeded with the first
        loss; 0.0 if the loader yields nothing) and the wall-clock duration of
        the epoch.
    """
    def prep(complex_img):
        # magnitude -> drop singleton dims -> 320x320 crop -> (1, 1, H, W) on GPU
        cropped = T.center_crop(T.complex_abs(complex_img).squeeze(), [320, 320])
        return cropped[None, None, :, :].cuda()

    model.train()
    running = 0.
    tic = time.perf_counter()
    for step, batch in enumerate(train_loader):
        gt, und, _, _, _ = batch
        net_in = prep(und)
        net_target = prep(gt)

        prediction = model(net_in)
        loss = F.l1_loss(prediction, net_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        value = loss.item()
        # First iteration seeds the running average; afterwards it decays.
        running = value if step == 0 else 0.99 * running + 0.01 * value
    return running, time.perf_counter() - tic
def get_epoch_batch(subject_id, acc, center_fract, is_undersampled, use_seed=True):
    """Load one k-space slice, undersample it with a mask and normalise it.

    Args:
        subject_id: ``(fname, rawdata_name, slice_idx)``. ``rawdata_name`` is
            the HDF5 file to read, ``slice_idx`` indexes the chosen dataset and
            ``fname`` is used to derive the mask seed.
        acc: acceleration factor passed to ``MaskFunc``.
        center_fract: centre fraction passed to ``MaskFunc``.
        is_undersampled: if true, read the ``'kspace_8af'`` dataset instead
            of ``'kspace'``.
        use_seed: if true, seed the mask from ``fname`` so a given file always
            gets the same mask; otherwise the mask is random.

    Returns:
        ``(img_gt, img_und, rawdata_und, masks, norm)``: the inverse-FFT image
        of the full k-space, the inverse-FFT image of the masked k-space and
        the masked k-space itself (all divided by ``norm``), the mask repeated
        to the k-space shape, and ``norm`` as a tensor clamped to >= 1e-6.
        The leading singleton dimension added on load is squeezed away.
    """
    fname, rawdata_name, slice_idx = subject_id
    dataset_key = 'kspace_8af' if is_undersampled else 'kspace'
    # Read only the requested slice; the file can be closed immediately after.
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data[dataset_key][slice_idx]

    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    S, Ny, _, ps = slice_kspace.shape

    # apply random mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = tuple(map(ord, fname)) if use_seed else None
    mask = mask_func(shape, seed)

    # undersample: zero every k-space entry where the mask is 0
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)

    # Data normalisation is important for the network to learn useful features.
    # During inference there is no ground-truth image, so the zero-filled
    # reconstruction sets the scale. Clamp (instead of replacing it with a
    # Python float) so that norm is always a tensor of the same dtype.
    norm = torch.clamp(T.complex_abs(img_und).max(), min=1e-6)

    img_gt, img_und, rawdata_und = img_gt / norm, img_und / norm, masked_kspace / norm
    return img_gt.squeeze(0), img_und.squeeze(0), rawdata_und.squeeze(0), masks.squeeze(0), norm
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True, *, is_undersampled=False):
    """Load one k-space slice, undersample it with a mask and normalise it.

    NOTE(review): a ``get_epoch_batch`` taking ``is_undersampled`` as its
    fourth positional argument is also defined in this file; whichever
    definition executes last is the one callers get. ``is_undersampled`` is
    accepted here keyword-only, so keyword callers of either form work. A
    5-positional-argument call still raises ``TypeError`` rather than
    silently binding arguments to the wrong parameters.

    Args:
        subject_id: ``(fname, rawdata_name, slice_idx)``. ``rawdata_name`` is
            the HDF5 file to read, ``slice_idx`` indexes the chosen dataset and
            ``fname`` is used to derive the mask seed.
        acc: acceleration factor passed to ``MaskFunc``.
        center_fract: centre fraction passed to ``MaskFunc``.
        use_seed: if true, seed the mask from ``fname`` so a given file always
            gets the same mask; otherwise the mask is random.
        is_undersampled: keyword-only; if true, read ``'kspace_8af'`` instead
            of ``'kspace'`` (default False keeps the original behavior).

    Returns:
        ``(img_gt, img_und, rawdata_und, masks, norm)``: the inverse-FFT image
        of the full k-space, the inverse-FFT image of the masked k-space and
        the masked k-space itself (all divided by ``norm``), the mask repeated
        to the k-space shape, and ``norm`` as a tensor clamped to >= 1e-6.
        The leading singleton dimension added on load is squeezed away.
    """
    fname, rawdata_name, slice_idx = subject_id
    dataset_key = 'kspace_8af' if is_undersampled else 'kspace'
    # Read only the requested slice; the file can be closed immediately after.
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data[dataset_key][slice_idx]

    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    S, Ny, _, ps = slice_kspace.shape

    # apply random mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = tuple(map(ord, fname)) if use_seed else None
    mask = mask_func(shape, seed)

    # undersample: zero every k-space entry where the mask is 0
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    # normalise data by the peak magnitude of the zero-filled reconstruction.
    # Clamp (instead of replacing it with a Python float) so that norm is
    # always a tensor of the same dtype.
    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)
    norm = torch.clamp(T.complex_abs(img_und).max(), min=1e-6)

    img_gt, img_und, rawdata_und = img_gt / norm, img_und / norm, masked_kspace / norm
    return img_gt.squeeze(0), img_und.squeeze(0), rawdata_und.squeeze(0), masks.squeeze(0), norm
def val_epoch(model, test_loader):
    """Compute the mean L1 loss of ``model`` over ``test_loader`` without training.

    Inputs and targets are prepared exactly as in training:
    1. take the magnitude of the complex image;
    2. squeeze singleton dims and centre-crop to 320x320;
    3. reshape to ``(1, 1, 320, 320)`` and move to the GPU.

    Returns:
        (mean_loss, elapsed_seconds). ``mean_loss`` is ``np.mean`` of the
        per-sample losses; it is NaN (with a NumPy warning) if the loader is empty.
    """
    # Setting model to evaluation mode
    model.eval()
    start_epoch = time.perf_counter()
    losses = []
    # Validation needs no gradients: disabling autograd avoids building and
    # holding a computation graph for every forward pass.
    with torch.no_grad():
        for sample in test_loader:
            img_gt, img_und, _, _, _ = sample
            # Extract input and ground truth image
            input_img = T.complex_abs(img_und).squeeze()
            input_img = T.center_crop(input_img, [320, 320])
            input_img = input_img[None, None, :, :].cuda()

            target_img = T.complex_abs(img_gt).squeeze()
            target_img = T.center_crop(target_img, [320, 320])
            target_img = target_img[None, None, :, :].cuda()

            output = model(input_img)
            loss = F.l1_loss(output, target_img)
            losses.append(loss.item())
    return np.mean(losses), time.perf_counter() - start_epoch
def eval_model(model, *, data_path='/data/local/NC2019MRI/train/', acc=8,
               cen_fract=0.04, num_workers=12, visualize=True):
    """Evaluate ``model`` on the held-out test split and return the mean SSIM.

    The test split is the last 20% of the sorted file names in ``data_path``.
    For each test volume:
    1. every slice is reconstructed with ``test_epoch`` and the slices are
       stacked into a volume;
    2. the ground truth is the magnitude of the inverse FFT of the file's
       ``'kspace'``, with its first 5 slices dropped and centre-cropped to
       320x320;
    3. the SSIM between the two volumes is computed.

    Args:
        model: network to evaluate.
        data_path: directory containing the HDF5 volumes.
        acc, cen_fract: acceleration and centre fraction for the sampling mask
            (the comment in the original code says these should be (8, 0.04)
            or (4, 0.08)).
        num_workers: DataLoader worker processes (0 loads in the main process).
        visualize: if true, plot a few predicted and ground-truth slices
            for each volume.

    Returns:
        The average SSIM over all test volumes.

    Raises:
        ValueError: if the test split contains no files.
    """
    all_files = sorted(os.listdir(data_path))
    # Use the last 20% of the data as the test set.
    test_files = all_files[round(len(all_files) * 0.8):]
    if not test_files:
        raise ValueError(f"no test volumes found in {data_path!r}")

    # False -> a random mask for each slice
    use_seed = False
    ssim_scores = []
    for file in test_files:
        # Load the file names, paths and slices for this volume.
        data_list = load_test_data_path(data_path, file)
        # DataLoader over all slices of this volume, kept in order.
        test_dataset = MRIDataset(data_list['test'], acceleration=acc, center_fraction=cen_fract,
                                  use_seed=use_seed, is_undersampled=False)
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1, num_workers=num_workers)

        # pred_vol is a list of predicted slices in loader order.
        pred_vol, _ = test_epoch(model, test_loader)
        pred = torch.stack(pred_vol, dim=0)

        # Load the ground truth.
        with h5py.File(os.path.join(data_path, file), "r") as hf:
            volume_kspace = hf['kspace'][()]
        # Complex k-space -> tensor -> inverse FFT -> magnitude image.
        volume_image_abs = T.complex_abs(T.ifft2(T.to_tensor(volume_kspace)))
        # NOTE(review): the first 5 slices are dropped so the ground truth lines
        # up with pred; presumably load_test_data_path skips them too — confirm.
        real_gt = T.center_crop(volume_image_abs[5:, :, :], [320, 320])

        if visualize:
            # Skip indices that do not exist in short volumes.
            shown = [i for i in (5, 10, 20, pred.shape[0] - 1) if i < pred.shape[0]]
            show_slices(pred, shown, cmap='gray')
            show_slices(real_gt, shown, cmap='gray')

        # SSIM between the ground-truth and predicted volumes.
        ssim_scores.append(ssim(real_gt.numpy(), pred.numpy()))

    return sum(ssim_scores) / len(ssim_scores)
def test_epoch(model, test_loader):
    """Run ``model`` over every sample in ``test_loader`` and collect its predictions.

    Each input is prepared as in training:
    1. take the magnitude of the complex image;
    2. squeeze singleton dims and centre-crop to 320x320;
    3. reshape to ``(1, 1, H, W)`` and move to the GPU.

    The ``[0, 0]`` plane of the model output is moved to the CPU and multiplied
    by the sample's ``norm``. This is presumably to undo the division by
    ``norm`` applied when the sample was loaded.

    Returns:
        (output_vol, elapsed_seconds): a list with one 2-D CPU tensor per
        sample, in loader order, and the wall-clock time of the pass.
    """
    model.eval()
    start_epoch = time.perf_counter()
    output_vol = []
    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for sample in test_loader:
            _, img_und, _, _, norm = sample
            # Extract the network input
            input_img = T.complex_abs(img_und).squeeze()
            input_img = T.center_crop(input_img, [320, 320])
            input_img = input_img[None, None, :, :].cuda()

            output = model(input_img)
            output = output[0, 0, :, :].cpu()
            output_vol.append(output * norm)
    return output_vol, time.perf_counter() - start_epoch
fig = plt.figure(figsize=(15,10)) for i, num in enumerate(slice_nums): plt.subplot(1, len(slice_nums), i + 1) plt.imshow(data[num], cmap=cmap) plt.axis('off') plt.show() show_slices(np.log(np.abs(volume_kspace) + 1e-9), [0, 10, 20, 30], cmap='gray') # Show slices in the sample as real images from functions import transforms as T volume_kspace2 = T.to_tensor(volume_kspace) volume_image = T.ifft2(volume_kspace2) volume_image_abs = T.complex_abs(volume_image) show_slices(volume_image_abs, [0, 10, 20, 30], cmap='gray') ## ## Simulate under-sample data ## import torch from functions.subsample import MaskFunc # Initiate MaskFunc() for 2 AF respectively mask_func0 = MaskFunc(center_fractions=[0.08], accelerations=[4]) mask_func1 = MaskFunc(center_fractions=[0.04], accelerations=[8]) # See subsample.py for detailed annotation for class MaskFunc().
# Preview a few random training samples. For each one, show side by side:
# the sampling mask, the log-magnitude of the undersampled k-space, the
# zero-filled (undersampled) image and the fully sampled image.
# NOTE(review): data_list, MRIDataset, DataLoader, T, show_slices and plt are
# defined or imported outside this chunk.
acc = 8  # acceleration factor for the sampling mask
cen_fract = 0.04  # centre fraction for the sampling mask
seed = False # random masks for each slice
num_workers = 12 # data loading workers
# create data loader for training set/ validation set
train_dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers)
for iteration, sample in enumerate(train_loader):
    img_gt, img_und, rawdata_und, masks, norm = sample
    # Build four views of this one sample and stack them so show_slices can
    # display them as slices 0-3.
    A = masks[..., 0].squeeze()
    # 1e-9 keeps log() finite where the masked k-space is exactly zero.
    B = torch.log(T.complex_abs(rawdata_und) + 1e-9).squeeze()
    C = T.complex_abs(img_und).squeeze()
    D = T.complex_abs(img_gt).squeeze()
    all_imgs = torch.stack([A, B, C, D], dim=0) # mask, masked space, undersampled image, ground truth
    show_slices(all_imgs, [0, 1, 2, 3], cmap='gray')
    plt.pause(1)
    if iteration >= 3:
        break # to show 4 random slices (iterations 0-3; the loader shuffles)