def test_varnet_num_sense_lines(shape, chans, center_fractions, accelerations, mask_center):
    """Check VarNet's low-frequency line bookkeeping and its output shape.

    A random mask is applied to a synthetic input with a fixed seed. When
    ``mask_center`` is exactly ``True``, the sensitivity network's
    ``get_pad_and_num_low_freqs`` must report the same number of low
    frequencies as ``apply_mask`` did, and the mask entries in
    ``[pad, pad + net_low_freqs)`` must all be one. In every case the forward
    pass (with ``num_low_frequencies=4``) must return an output whose trailing
    dimensions equal ``x.shape[2:4]``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    masked_input, mask, num_low_freqs = transforms.apply_mask(x, sampler, seed=123)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    if mask_center is True:
        pad, net_low_freqs = model.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs
        )
        assert net_low_freqs == num_low_freqs

        # The centre band reported by the network must be fully sampled.
        band_start = int(pad)
        band_stop = int(pad + net_low_freqs)
        centre_band = mask.squeeze()[band_start:band_stop].to(torch.int8)
        all_sampled = torch.ones([int(net_low_freqs)], dtype=torch.int8)
        assert torch.allclose(centre_band, all_sampled)

    y = model(masked_input, mask.byte(), num_low_frequencies=4)

    assert y.shape[1:] == x.shape[2:4]
def __getitem__(self, i: int):
    """Return the ``i``-th example as a pair of disjoint k-space splits.

    The slice's k-space is read from the HDF5 file and a freshly built
    ``RandomMaskFunc(center_fractions=[0.04], accelerations=[8])`` is applied
    to it. Index 0 of the masked tensor's third axis becomes the ``target``
    entry of the sample; ``kspace`` minus that array becomes the ``kspace``
    entry.

    NOTE(review): ``[:, :, 0]`` presumably selects the real channel of a
    single-coil ``(H, W, 2)`` tensor produced by ``T.to_tensor`` -- confirm
    against the data this dataset is used with.

    Args:
        i: Index into ``self.examples``.

    Returns:
        ``(kspace, mask, target, attrs, fname.name, dataslice)`` when
        ``self.transform`` is ``None``, otherwise whatever ``self.transform``
        returns for those six arguments. ``mask`` is the file's ``"mask"``
        dataset as a NumPy array, or ``None`` when the file has none.
    """
    fname, dataslice, metadata = self.examples[i]

    # Only the reads that need the file handle happen while it is open.
    with h5py.File(fname, "r") as hf:
        kspace = hf["kspace"][dataslice]
        mask = np.asarray(hf["mask"]) if "mask" in hf else None
        attrs = dict(hf.attrs)
        attrs.update(metadata)

    # The sampling mask returned by apply_mask is not part of the sample; the
    # sample's `mask` entry always comes from the file (or is None).
    mask_func = RandomMaskFunc(center_fractions=[0.04], accelerations=[8])
    masked_kspace, _ = T.apply_mask(T.to_tensor(kspace), mask_func)
    loss_masked_kspace = masked_kspace[:, :, 0].numpy()
    trn_masked_kspace = kspace - loss_masked_kspace

    # The stored reconstruction is intentionally not read: the target is
    # always replaced by the masked split, so loading it was a wasted read.
    kspace = trn_masked_kspace
    target = loss_masked_kspace

    if self.transform is None:
        return (kspace, mask, target, attrs, fname.name, dataslice)
    return self.transform(kspace, mask, target, attrs, fname.name, dataslice)
def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    """Run VarNet on a batch whose items were masked one at a time.

    Each batch element of the synthetic input is masked separately (always
    with ``seed=123``). The masked items and their masks are concatenated
    back into a batch and fed through a small VarNet. The output's trailing
    dimensions must equal ``x.shape[2:4]``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # One apply_mask call per batch element, in batch order.
    per_item = [
        transforms.apply_mask(x[idx : idx + 1], sampler, seed=123)
        for idx in range(x.shape[0])
    ]
    output = torch.cat([masked for masked, _, _ in per_item])
    mask = torch.cat([item_mask for _, item_mask, _ in per_item])

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    y = model(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
def load_data_from_pathlist(path):
    """Build normalised (target, under-sampled image) tensors from HDF5 files.

    Only the first ``len(path) // 3`` entries of ``path`` are used. For each
    of those files every k-space slice is converted to a float tensor and
    masked with ``RandomMaskFunc(center_fractions=[0.08], accelerations=[4])``.
    It is then inverse Fourier transformed, centre-cropped to 320x320 and
    reduced to a magnitude image. Slices of one file are stacked; files are
    concatenated along dimension 0.

    The under-sampled images are instance-normalised and clamped to
    ``[-6, 6]``. The targets are normalised with the *same* mean and standard
    deviation and clamped to the same range.

    Args:
        path: Indexable sequence of file paths accepted by ``load_dataset``,
            which is expected to return ``(kspace_volume, num_slices, target)``.

    Returns:
        ``(target_image_tensor, total_sampled_image_tensor)`` -- note that the
        targets come first.
    """
    use_num = len(path) // 3

    total_target_list = []
    total_sampled_image_list = []
    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])

        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()
            # Only the masked k-space is needed; the mask itself is unused.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Statistics come from the under-sampled images and are reused for the
    # targets so both tensors live on the same scale.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11
    )
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)

    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)

    return target_image_tensor, total_sampled_image_tensor
def load_data(file_dir_path):
    """Load fully-sampled and under-sampled magnitude images for a directory.

    Every file returned by ``get_files(file_dir_path)`` is loaded with
    ``load_dataset``. For each slice two magnitude images are produced: one
    from the full k-space, and one from k-space masked with
    ``RandomMaskFunc(center_fractions=[0.08], accelerations=[4])``. Slices are
    stacked per file and files are stacked into a leading dimension, so all
    files must yield the same number of slices and the same slice shape.

    The shapes of both result tensors are printed before returning.

    Args:
        file_dir_path: Directory (or whatever ``get_files`` accepts) holding
            the data files.

    Returns:
        ``(total_image_tensor, total_sampled_image_tensor)``.
    """
    file_path = get_files(file_dir_path)

    total_image_list = []
    total_sampled_image_list = []
    # Load each file in turn; the earlier code indexed file_path[0] inside
    # this loop and therefore loaded the first file once per file.
    for h5_file in file_path:
        total_kspace, slices_num = load_dataset(h5_file)

        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

        image_list = []
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i])

            # Fully-sampled reference image.
            slice_image = fastmri.ifft2c(slice_kspace_tensor)
            image_list.append(fastmri.complex_abs(slice_image))

            # Under-sampled counterpart; the mask itself is not needed.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_image_list.append(torch.stack(image_list, dim=0))
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))

    total_image_tensor = torch.stack(total_image_list, dim=0)
    total_sampled_image_tensor = torch.stack(total_sampled_image_list, dim=0)
    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)
    return total_image_tensor, total_sampled_image_tensor
def test_apply_mask(shape, center_fractions, accelerations):
    """Verify ``transforms.apply_mask`` against the mask function itself.

    With the same seed, ``apply_mask`` must return the mask that calling the
    mask function directly produces, keep the input's shape, and leave
    non-zero values only where the mask is non-zero. NumPy's global random
    state must be the same before and after.
    """
    rng_state_before = np.random.get_state()

    sampler = RandomMaskFunc(center_fractions, accelerations)
    expected_mask = sampler(shape, seed=123)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, sampler, seed=123)

    # Seeded masking must not disturb the global NumPy generator.
    rng_state_after = np.random.get_state()
    assert (rng_state_before[1] == rng_state_after[1]).all()

    assert output.shape == x.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())

    # Zeroing the output wherever the mask is zero must change nothing.
    output_np = output.numpy()
    zeroed_outside_mask = np.where(mask.numpy() == 0, 0, output_np)
    assert np.all(zeroed_outside_mask == output_np)
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    """Smoke-test a small VarNet forward pass on masked synthetic input.

    ``out_chans`` and ``chans`` are accepted but not used by the body; the
    network is always built with ``chans=4``. The output's trailing
    dimensions must equal ``x.shape[2:4]``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    masked_input, mask = transforms.apply_mask(x, sampler, seed=123)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=4,
        pools=2,
    )

    y = model(masked_input, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
def test_apply_mask(shape, center_fractions, accelerations):
    """Verify ``transforms.apply_mask``, including the low-frequency count.

    The mask function is called directly and through ``apply_mask`` with the
    same seed; both must agree on the mask and on the number of low
    frequencies. That number must equal ``round(cf * shape[-2])`` for one of
    the given centre fractions. The output must keep the input's shape and be
    non-zero only where the mask is non-zero, and NumPy's global random state
    must be the same before and after.
    """
    rng_state_before = np.random.get_state()

    sampler = RandomMaskFunc(center_fractions, accelerations)
    expected_mask, expected_num_low_frequencies = sampler(shape, seed=123)

    allowed_counts = [round(cf * shape[-2]) for cf in center_fractions]
    assert expected_num_low_frequencies in allowed_counts

    x = create_input(shape)
    output, mask, num_low_frequencies = transforms.apply_mask(
        x, sampler, seed=123
    )

    # Seeded masking must not disturb the global NumPy generator.
    rng_state_after = np.random.get_state()
    assert (rng_state_before[1] == rng_state_after[1]).all()

    assert output.shape == x.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())

    # Zeroing the output wherever the mask is zero must change nothing.
    output_np = output.numpy()
    zeroed_outside_mask = np.where(mask.numpy() == 0, 0, output_np)
    assert np.all(zeroed_outside_mask == output_np)

    assert num_low_frequencies == expected_num_low_frequencies
slice_image_rss = fastmri.rss(slice_image_abs, dim=0) # In[17]: plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray') # So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space. # In[18]: from fastmri.data.subsample import RandomMaskFunc mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4]) # Create the mask function object # In[19]: masked_kspace, mask = T.apply_mask(slice_kspace2, mask_func) # Apply the mask to k-space # Let's see what the subsampled image looks like: # In[20]: sampled_image = fastmri.ifft2c(masked_kspace) # Apply Inverse Fourier Transform to get the complex image sampled_image_abs = fastmri.complex_abs(sampled_image) # Compute absolute value to get a real image