Example #1
def test_varnet_num_sense_lines(shape, chans, center_fractions, accelerations,
                                mask_center):
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    output, mask, num_low_freqs = transforms.apply_mask(x, mask_func, seed=123)

    varnet = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    if mask_center:
        pad, net_low_freqs = varnet.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs)
        assert net_low_freqs == num_low_freqs
        assert torch.allclose(
            mask.squeeze()[int(pad):int(pad + net_low_freqs)].to(torch.int8),
            torch.ones([int(net_low_freqs)], dtype=torch.int8),
        )

    y = varnet(output, mask.byte(), num_low_frequencies=4)

    assert y.shape[1:] == x.shape[2:4]
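These test snippets assume the imports and the create_input helper from the fastMRI test suite. A minimal sketch of that setup, where create_input is a hypothetical stand-in that simply builds a random tensor whose trailing dimension of size 2 holds the real and imaginary parts:

import numpy as np
import torch

from fastmri.data import transforms
from fastmri.data.subsample import RandomMaskFunc
from fastmri.models import VarNet


def create_input(shape):
    # Hypothetical stand-in for the test-suite helper: random k-space
    # of shape (batch, coils, height, width, 2).
    return torch.rand(*shape)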
Example #2
    def __getitem__(self, i: int):
        fname, dataslice, metadata = self.examples[i]

        with h5py.File(fname, "r") as hf:
            kspace = hf["kspace"][dataslice]

######################################################

            # Create the mask function and apply it to this slice's k-space
            mask_func = RandomMaskFunc(center_fractions=[0.04], accelerations=[8])
            masked_kspace, mask = T.apply_mask(T.to_tensor(kspace), mask_func)
            # Keep the first component of the trailing complex dimension
            # (the real part) of the sampled lines as the loss-side k-space
            loss_masked_kspace = masked_kspace[:, :, 0].numpy()
            # The training side is the complement of the sampled lines
            trn_masked_kspace = kspace - loss_masked_kspace

######################################################

            # Note: this overwrites the mask returned by apply_mask above
            mask = np.asarray(hf["mask"]) if "mask" in hf else None

            target = hf[self.recons_key][dataslice] if self.recons_key in hf else None

######################################################
            kspace = trn_masked_kspace   # train on the complement
            target = loss_masked_kspace  # supervise on the sampled lines
######################################################

            attrs = dict(hf.attrs)
            attrs.update(metadata)

        if self.transform is None:
            sample = (kspace, mask, target, attrs, fname.name, dataslice)
        else:
            sample = self.transform(kspace, mask, target, attrs, fname.name, dataslice)

        return sample
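This __getitem__ splits each slice for self-supervised training: the randomly sampled lines become the target, and their complement becomes the input. A quick sanity check of the split, assuming single-coil slices and the real-part extraction above:

# The two halves recombine exactly to the original k-space,
# since trn_masked_kspace was defined as their difference.
assert np.allclose(trn_masked_kspace + loss_masked_kspace, kspace)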
Example #3
def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    outputs, masks = [], []
    for i in range(x.shape[0]):
        output, mask, _ = transforms.apply_mask(x[i:i + 1],
                                                mask_func,
                                                seed=123)
        outputs.append(output)
        masks.append(mask)

    output = torch.cat(outputs)
    mask = torch.cat(masks)

    varnet = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    y = varnet(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
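Note that apply_mask is called with the same fixed seed=123 for every batch element, so all slices receive an identical mask; vary the seed per slice (e.g. seed=123 + i) if the test should exercise distinct masks.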
Example #4
def load_data_from_pathlist(path):
    file_num = len(path)
    use_num = file_num // 3  # only use a third of the files
    total_target_list = []
    total_sampled_image_list = []
    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])
        slice_kspace_tensor_list = []
        for i in range(slices_num):
            slice_kspace = total_kspace[i]
            # Convert the numpy k-space slice to a float tensor
            slice_kspace_tensor = T.to_tensor(slice_kspace).float()
            slice_kspace_tensor_list.append(slice_kspace_tensor)
        # Create the mask function: 8% center fraction, 4x acceleration
        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = slice_kspace_tensor_list[i]
            masked_kspace, mask = T.apply_mask(slice_kspace_tensor, mask_func)
            # Inverse FFT gives the zero-filled complex image;
            # center-crop it and take the magnitude
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_abs = fastmri.complex_abs(sampled_image)
            sampled_image_list.append(sampled_image_abs)
        sampled_image_list_tensor = torch.stack(
            sampled_image_list, dim=0)  # torch.Size([slices_num, 320, 320])
        total_sampled_image_list.append(sampled_image_list_tensor)
        total_target_list.append(T.to_tensor(target))
    # Concatenate all volumes along the slice dimension
    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)
    # Normalize with the input statistics, then clip outliers to +/- 6 std
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11)
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)
    return target_image_tensor, total_sampled_image_tensor
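A hypothetical usage sketch (the directory name is an assumption, and load_dataset is assumed to return a volume's k-space array, its slice count, and the reconstruction target):

from pathlib import Path

h5_paths = [str(p) for p in sorted(Path("singlecoil_train").glob("*.h5"))]
target_image_tensor, sampled_image_tensor = load_data_from_pathlist(h5_paths)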
Example #5
def load_data(file_dir_path):
    file_path = get_files(file_dir_path)
    file_num = len(file_path)
    total_image_list = []
    total_sampled_image_list = []
    for h5_num in range(file_num):
        total_kspace, slices_num = load_dataset(file_path[h5_num])
        image_list = []
        slice_kspace_tensor_list = []
        for i in range(slices_num):
            slice_kspace = total_kspace[i]
            # Convert the numpy k-space slice to a tensor
            slice_kspace_tensor = T.to_tensor(slice_kspace)
            # Inverse FFT, then magnitude to get a real image
            slice_image = fastmri.ifft2c(slice_kspace_tensor)
            slice_image_abs = fastmri.complex_abs(slice_image)
            image_list.append(slice_image_abs)
            slice_kspace_tensor_list.append(
                slice_kspace_tensor)  # slices_num tensors of [640, 368, 2]

        image_list_tensor = torch.stack(image_list,
                                        dim=0)  # torch.Size([35, 640, 368])
        total_image_list.append(image_list_tensor)
        # Create the mask function: 8% center fraction, 4x acceleration
        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = slice_kspace_tensor_list[i]
            masked_kspace, mask = T.apply_mask(slice_kspace_tensor, mask_func)
            # Inverse FFT of the masked k-space gives the zero-filled complex image
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image_abs = fastmri.complex_abs(sampled_image)
            sampled_image_list.append(sampled_image_abs)
        sampled_image_list_tensor = torch.stack(
            sampled_image_list, dim=0)  # torch.Size([35, 640, 368])
        total_sampled_image_list.append(sampled_image_list_tensor)
    # Stack per-volume tensors; every volume must have the same slice count
    total_image_tensor = torch.stack(total_image_list,
                                     dim=0)  # torch.Size([199, 35, 640, 368])
    total_sampled_image_tensor = torch.stack(
        total_sampled_image_list, dim=0)  # torch.Size([199, 35, 640, 368])
    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)
    return total_image_tensor, total_sampled_image_tensor
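Unlike Example #4, which concatenates slices into one flat dimension, this version stacks a new volume dimension, so it only works when every file holds the same number of slices.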
Example #6
def test_apply_mask(shape, center_fractions, accelerations):
    state = np.random.get_state()

    mask_func = RandomMaskFunc(center_fractions, accelerations)
    expected_mask = mask_func(shape, seed=123)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, mask_func, seed=123)

    assert (state[1] == np.random.get_state()[1]).all()
    assert output.shape == x.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())
    assert np.all(
        np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())
Example #7
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, mask_func, seed=123)

    varnet = VarNet(num_cascades=2,
                    sens_chans=4,
                    sens_pools=2,
                    chans=chans,
                    pools=2)

    y = varnet(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
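Note that these examples target different fastMRI releases: here (and in Examples #2 and #4 through #6) transforms.apply_mask returns (masked_kspace, mask), while the newer API used in Examples #1, #3, and #8 also returns the number of low-frequency lines.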
Example #8
def test_apply_mask(shape, center_fractions, accelerations):
    state = np.random.get_state()

    mask_func = RandomMaskFunc(center_fractions, accelerations)
    expected_mask, expected_num_low_frequencies = mask_func(shape, seed=123)
    assert expected_num_low_frequencies in [
        round(cf * shape[-2]) for cf in center_fractions
    ]
    x = create_input(shape)
    output, mask, num_low_frequencies = transforms.apply_mask(x,
                                                              mask_func,
                                                              seed=123)

    assert (state[1] == np.random.get_state()[1]).all()
    assert output.shape == x.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())
    assert np.all(
        np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())
    assert num_low_frequencies == expected_num_low_frequencies
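The num_low_frequencies assertion follows from how RandomMaskFunc sizes the fully sampled center: with a center fraction cf and W phase-encode columns, it keeps round(cf * W) lines. For example, cf = 0.08 and W = 368 give round(29.44) = 29 center lines.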
Example #9
slice_image_rss = fastmri.rss(slice_image_abs, dim=0)


# In[17]:


plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray')


# So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space.

# In[18]:


from fastmri.data.subsample import RandomMaskFunc
mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])  # Create the mask function object


# In[19]:


masked_kspace, mask = T.apply_mask(slice_kspace2, mask_func)   # Apply the mask to k-space


# Let's see what the subsampled image looks like:

# In[20]:


sampled_image = fastmri.ifft2c(masked_kspace)           # Apply Inverse Fourier Transform to get the complex image
sampled_image_abs = fastmri.complex_abs(sampled_image)   # Compute absolute value to get a real image
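The transcript ends before displaying the result. A natural continuation, assuming the same multi-coil pipeline as above, combines the coils with a root-sum-of-squares and plots the zero-filled reconstruction:

sampled_image_rss = fastmri.rss(sampled_image_abs, dim=0)   # Combine coil images
plt.imshow(np.abs(sampled_image_rss.numpy()), cmap='gray')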