def train_epoch(model, optimizer, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Every sample is unpacked as ``(img_gt, img_und, rawdata_und, masks, norm)``.
    The magnitude images of ``img_und`` (network input) and ``img_gt``
    (target) are squeezed, centre-cropped to 320x320, given leading batch and
    channel axes and moved to the GPU. ``optimizer`` then takes one step on
    the L1 loss between the network output and the target.

    Returns:
        tuple: ``(running_loss, seconds)``.
            ``running_loss`` is an exponential moving average of the batch
            losses, seeded with the first batch's loss and weighting the
            history by 0.99. It is ``0.0`` if the loader is empty.
            ``seconds`` is the epoch's wall-clock time.
    """
    def prepare(complex_img):
        # |x| -> drop singleton axes -> 320x320 centre crop -> (1, 1, H, W) on GPU
        cropped = T.center_crop(T.complex_abs(complex_img).squeeze(), [320, 320])
        return cropped[None, None, :, :].cuda()

    model.train()
    running_loss = 0.
    epoch_start = time.perf_counter()

    for step, batch in enumerate(train_loader):
        img_gt, img_und, _rawdata_und, _masks, _norm = batch

        net_input = prepare(img_und)
        net_target = prepare(img_gt)

        prediction = model(net_input)
        loss = F.l1_loss(prediction, net_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        if step == 0:
            running_loss = batch_loss
        else:
            running_loss = 0.99 * running_loss + 0.01 * batch_loss

    return running_loss, time.perf_counter() - epoch_start
def get_epoch_batch(subject_id, acc, center_fract, is_undersampled, use_seed=True):
    """Load one k-space slice, apply a sampling mask and normalise the result.

    NOTE(review): this file defines ``get_epoch_batch`` a second time further
    down (without ``is_undersampled``). That later definition replaces this
    one when the module is executed.

    Args:
        subject_id: ``(fname, rawdata_name, slice)``. ``rawdata_name`` is the
            HDF5 file to open and ``slice`` is the index read from its dataset.
        acc: acceleration passed to ``MaskFunc``.
        center_fract: centre fraction passed to ``MaskFunc``.
        is_undersampled: when truthy, read the ``'kspace_8af'`` dataset
            instead of ``'kspace'``.
            NOTE(review): ``img_gt`` is then reconstructed from that dataset
            too. If it holds undersampled data, as the name suggests, the
            "ground truth" is not fully sampled; confirm this is intended.
        use_seed: when true, the mask seed is the tuple of character codes of
            ``fname``. Otherwise ``None`` is passed as the seed.

    Returns:
        tuple: ``(img_gt, img_und, rawdata_und, masks, norm)``.
            The first four have ``squeeze(0)`` applied.
            ``img_gt``, ``img_und`` and ``rawdata_und`` are divided by
            ``norm``, the peak magnitude of the zero-filled image, floored
            at 1e-6.
    """
    fname, rawdata_name, slice_idx = subject_id

    dataset_key = 'kspace_8af' if is_undersampled else 'kspace'
    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data[dataset_key][slice_idx]

    # prepend a singleton axis; the 4-way unpack below fixes the rank at 4
    kspace = T.to_tensor(rawdata).unsqueeze(0)
    n_s, n_y, _n_x, n_p = kspace.shape

    kspace_shape = np.array(kspace.shape)
    sampler = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = tuple(map(ord, fname)) if use_seed else None
    mask = sampler(kspace_shape, seed)

    # zero every k-space sample the mask rejects
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), kspace)
    masks = mask.repeat(n_s, n_y, 1, n_p)

    img_gt = T.ifft2(kspace)
    img_und = T.ifft2(masked_kspace)

    # Normalise by the zero-filled reconstruction: at inference time no
    # ground truth is available, so it cannot be used for scaling.
    norm = T.complex_abs(img_und).max()
    norm = 1e-6 if norm < 1e-6 else norm

    img_gt = img_gt / norm
    img_und = img_und / norm
    rawdata_und = masked_kspace / norm

    return (img_gt.squeeze(0), img_und.squeeze(0), rawdata_und.squeeze(0),
            masks.squeeze(0), norm)
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True):
    """Return one normalised, retrospectively undersampled k-space slice.

    ``subject_id`` is unpacked as ``(fname, rawdata_name, slice)``. Slice
    ``slice`` of the ``'kspace'`` dataset in HDF5 file ``rawdata_name`` is
    read, and a ``MaskFunc`` mask (acceleration ``acc``, centre fraction
    ``center_fract``) zeroes the rejected samples. Both the full and the
    masked k-space are then inverse-transformed with ``T.ifft2``.

    When ``use_seed`` is true the mask seed is the tuple of character codes
    of ``fname``; otherwise ``None`` is passed as the seed.

    Returns:
        tuple: ``(img_gt, img_und, rawdata_und, masks, norm)``.
            The first four have ``squeeze(0)`` applied.
            ``img_gt``, ``img_und`` and ``rawdata_und`` are divided by
            ``norm``, the maximum magnitude of the zero-filled image,
            floored at 1e-6.
    """
    fname, h5_path, slice_index = subject_id

    with h5py.File(h5_path, 'r') as h5_file:
        raw_slice = h5_file['kspace'][slice_index]

    full_kspace = T.to_tensor(raw_slice).unsqueeze(0)
    num_s, num_y, _num_x, num_p = full_kspace.shape

    mask_shape = np.array(full_kspace.shape)
    sampler = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    mask_seed = tuple(map(ord, fname)) if use_seed else None
    mask = sampler(mask_shape, mask_seed)

    zero = torch.Tensor([0])
    masked_kspace = torch.where(mask == 0, zero, full_kspace)
    masks = mask.repeat(num_s, num_y, 1, num_p)

    img_gt = T.ifft2(full_kspace)
    img_und = T.ifft2(masked_kspace)

    # scale everything by the zero-filled image's peak, with a small floor
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6:
        norm = 1e-6

    scaled_gt = img_gt / norm
    scaled_und = img_und / norm
    rawdata_und = masked_kspace / norm

    return (
        scaled_gt.squeeze(0),
        scaled_und.squeeze(0),
        rawdata_und.squeeze(0),
        masks.squeeze(0),
        norm,
    )
def val_epoch(model, test_loader):
    """Compute the mean L1 loss of ``model`` over ``test_loader``.

    Each sample is unpacked as ``(img_gt, img_und, rawdata_und, masks, norm)``.
    The magnitudes of ``img_und`` (input) and ``img_gt`` (target) are
    squeezed, centre-cropped to 320x320, given leading batch/channel axes and
    moved to the GPU before the L1 loss is taken.

    Args:
        model: network to evaluate. It is switched to ``eval()`` mode and left
            in that mode.
        test_loader: iterable of samples as described above.

    Returns:
        tuple: ``(mean_loss, seconds)``.
            ``mean_loss`` is the mean per-batch L1 loss, or NaN if the loader
            is empty.
            ``seconds`` is the wall-clock time of the pass.
    """
    import torch  # local import keeps this block self-contained

    model.eval()
    start_epoch = time.perf_counter()
    losses = []

    # Gradients are never used during validation. Without no_grad(), autograd
    # would record a graph for every batch, wasting GPU memory and time.
    with torch.no_grad():
        for sample in test_loader:
            img_gt, img_und, rawdata_und, masks, norm = sample

            # Extract input and ground truth image
            input_img = T.complex_abs(img_und).squeeze()
            input_img = T.center_crop(input_img, [320, 320])
            input_img = input_img[None, None, :, :].cuda()
            target_img = T.complex_abs(img_gt).squeeze()
            target_img = T.center_crop(target_img, [320, 320])
            target_img = target_img[None, None, :, :].cuda()

            output = model(input_img)
            loss = F.l1_loss(output, target_img)
            losses.append(loss.item())

    # np.mean([]) also yields NaN, but emits a RuntimeWarning; be explicit.
    mean_loss = np.mean(losses) if losses else float('nan')
    return mean_loss, time.perf_counter() - start_epoch
def _load_ground_truth_volume(file_path, skip_slices=5, crop_size=(320, 320)):
    """Reconstruct a centre-cropped magnitude volume from a file's ``kspace``.

    The whole ``'kspace'`` dataset of the HDF5 file is read and converted with
    ``T.to_tensor``. It is inverse-Fourier transformed with ``T.ifft2`` and
    reduced to its magnitude. The first ``skip_slices`` slices are discarded
    and the rest is centre-cropped to ``crop_size``.

    NOTE(review): the default of 5 skipped slices reproduces the original
    ``[5:]`` slicing. Presumably it matches the slices ``load_test_data_path``
    yields, so this volume lines up with the prediction; confirm.
    """
    with h5py.File(file_path, "r") as hf:
        volume_kspace = hf['kspace'][()]

    volume_image = T.ifft2(T.to_tensor(volume_kspace))
    volume_image_abs = T.complex_abs(volume_image)[skip_slices:, :, :]
    return T.center_crop(volume_image_abs, list(crop_size))


def eval_model(model, *, data_path='/data/local/NC2019MRI/train/', acc=8,
               cen_fract=0.04, num_workers=12):
    """Score ``model`` by volume-wise SSIM on the last 20% of ``data_path``.

    For every test file:
      - its slices are run through ``test_epoch`` to build a predicted volume;
      - the ground truth is rebuilt from the file's ``'kspace'`` dataset;
      - a few slices of both are displayed with ``show_slices``;
      - ``ssim`` is computed between the two volumes.

    Args:
        model: network passed to ``test_epoch``.
        data_path: directory of HDF5 files. The last 20% of its sorted
            listing forms the test split.
        acc: acceleration passed to ``MRIDataset``. The original notes say
            ``(acc, cen_fract)`` should be ``(8, 0.04)`` or ``(4, 0.08)``.
        cen_fract: centre fraction passed to ``MRIDataset``.
        num_workers: ``DataLoader`` worker count. 0 loads in the main process.

    Returns:
        float: mean SSIM over the test files.

    Raises:
        ValueError: if the test split contains no files.
    """
    all_files = sorted(os.listdir(data_path))
    # hold out the last 20% of the sorted file list as the test set
    test_files = all_files[round(len(all_files) * 0.8):]
    if not test_files:
        raise ValueError(f"no test files found in {data_path!r}")

    use_seed = False  # draw a fresh random mask for each slice
    all_scores = []

    for file in test_files:
        # first load all file names, paths and slices
        data_list = load_test_data_path(data_path, file)
        test_dataset = MRIDataset(data_list['test'], acceleration=acc,
                                  center_fraction=cen_fract, use_seed=use_seed,
                                  is_undersampled=False)
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1,
                                 num_workers=num_workers)

        # pred_vol is a list of predicted slices in loader (slice) order
        pred_vol, _ = test_epoch(model, test_loader)
        pred = torch.stack(pred_vol, dim=0)

        real_gt = _load_ground_truth_volume(os.path.join(data_path, file))

        # Only show slice indices that exist; the fixed 5/10/20 would raise
        # IndexError on volumes with fewer than 21 slices.
        n_slices = pred.shape[0]
        shown = [i for i in (5, 10, 20, n_slices - 1) if 0 <= i < n_slices]
        show_slices(pred, shown, cmap='gray')
        show_slices(real_gt, shown, cmap='gray')

        all_scores.append(ssim(real_gt.numpy(), pred.numpy()))

    return sum(all_scores) / len(all_scores)
def test_epoch(model, test_loader):
    """Run ``model`` over ``test_loader`` and collect de-normalised predictions.

    Each sample is unpacked as ``(img_gt, img_und, rawdata_und, masks, norm)``.
    The magnitude of ``img_und`` is squeezed, centre-cropped to 320x320, given
    leading batch/channel axes and moved to the GPU. The ``[0, 0]`` plane of
    the model output is then brought back to the CPU and multiplied by
    ``norm`` to undo the loader's normalisation.

    Returns:
        tuple: ``(output_vol, seconds)``.
            ``output_vol`` is a list with one prediction per loader item, in
            loader order.
            ``seconds`` is the wall-clock time of the pass.
    """
    import torch  # local import keeps this block self-contained

    model.eval()
    start_epoch = time.perf_counter()
    output_vol = []

    # Inference only: no_grad() stops autograd from building a graph per batch.
    # Calling .detach() afterwards, as before, did not prevent that cost.
    with torch.no_grad():
        for sample in test_loader:
            img_gt, img_und, rawdata_und, masks, norm = sample

            # Extract input image
            input_img = T.complex_abs(img_und).squeeze()
            input_img = T.center_crop(input_img, [320, 320])
            input_img = input_img[None, None, :, :].cuda()

            output = model(input_img)[0, 0, :, :].cpu()
            output_vol.append(output * norm)

    return output_vol, time.perf_counter() - start_epoch
Example #7
0
    fig = plt.figure(figsize=(15,10))
    for i, num in enumerate(slice_nums):
        plt.subplot(1, len(slice_nums), i + 1)
        plt.imshow(data[num], cmap=cmap)
        plt.axis('off')
        plt.show()

# Display the log-magnitude of k-space for slices 0, 10, 20 and 30.
# NOTE(review): `volume_kspace` is not defined in this excerpt; presumably a
# complex numpy array read from an HDF5 'kspace' dataset earlier - confirm.
# The 1e-9 offset keeps np.log finite where the magnitude is exactly zero.
show_slices(np.log(np.abs(volume_kspace) + 1e-9), [0, 10, 20, 30], cmap='gray')

# Show the same slices as real-valued images: convert k-space to a tensor,
# apply the inverse Fourier transform (T.ifft2) to get the complex image,
# then take its magnitude (T.complex_abs).

from functions import transforms as T

volume_kspace2 = T.to_tensor(volume_kspace)
volume_image = T.ifft2(volume_kspace2)
volume_image_abs = T.complex_abs(volume_image)

show_slices(volume_image_abs, [0, 10, 20, 30], cmap='gray')

##
## Simulate undersampled data
##

import torch
from functions.subsample import MaskFunc

# Instantiate one MaskFunc per acceleration factor (AF):
# acceleration 4 with centre fraction 0.08, and acceleration 8 with centre
# fraction 0.04.
mask_func0 = MaskFunc(center_fractions=[0.08], accelerations=[4])
mask_func1 = MaskFunc(center_fractions=[0.04], accelerations=[8])
# See subsample.py for the detailed documentation of class MaskFunc().
    acc = 8
    cen_fract = 0.04
    seed = False  # random masks for each slice
    num_workers = 12  # data loading workers

    # create data loader for training set/ validation set
    train_dataset = MRIDataset(data_list['train'],
                               acceleration=acc,
                               center_fraction=cen_fract,
                               use_seed=seed)
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=1,
                              num_workers=num_workers)
    for iteration, sample in enumerate(train_loader):

        img_gt, img_und, rawdata_und, masks, norm = sample

        # stack different slices into a volume for visualisation
        A = masks[..., 0].squeeze()
        B = torch.log(T.complex_abs(rawdata_und) + 1e-9).squeeze()
        C = T.complex_abs(img_und).squeeze()
        D = T.complex_abs(img_gt).squeeze()
        all_imgs = torch.stack([A, B, C, D], dim=0)

        # mask, masked space, undersampled image, ground truth
        show_slices(all_imgs, [0, 1, 2, 3], cmap='gray')
        plt.pause(1)

        if iteration >= 3: break  # to show 4 random slices