Example #1
def run_expmt(file_id_list):

    for file_id in file_id_list:

        if os.path.exists('{}{}_dc.npy'.format(path_out, file_id)):
            continue  # already reconstructed

        f, ksp_orig = load_h5_fastmri(file_id)
        ksp_orig = torch.from_numpy(ksp_orig)

        mask = get_mask(ksp_orig)

        net, net_input, ksp_orig_ = init_convdecoder(ksp_orig, mask)

        ksp_masked = 0.1 * ksp_orig_ * mask  # undersample the rescaled k-space (0.1 is a fixed scale factor)

        net = fit(ksp_masked, net, net_input, mask)

        img_out = net(net_input.type(dtype))
        img_out = reshape_adj_channels_to_complex_vals(img_out[0])
        ksp_est = fft_2d(img_out)
        ksp_dc = torch.where(mask, ksp_masked, ksp_est)  # data consistency: keep measured samples

        #img_est = crop_center(root_sum_squares(ifft_2d(ksp_est)).detach(), dim, dim)
        img_dc = crop_center(root_sum_squares(ifft_2d(ksp_dc)).detach(), dim, dim)
        img_gt = crop_center(root_sum_squares(ifft_2d(ksp_orig)), dim, dim)
        # note: use unscaled ksp_orig to make gt -- different from original processing

        np.save('{}{}_dc.npy'.format(path_out, file_id), img_dc)
        np.save('{}{}_gt.npy'.format(path_out, file_id), img_gt)
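Every example here leans on a handful of helpers (fft_2d, ifft_2d, root_sum_squares, crop_center) whose definitions are not shown. Below is a minimal sketch of plausible implementations, assuming centered orthonormal 2D FFTs over the last two dimensions and a leading coil dimension; the real signatures in the source repo may differ.

import torch

def fft_2d(x):
    # centered, orthonormal 2D FFT over the last two dims (complex in, complex out)
    return torch.fft.fftshift(
        torch.fft.fft2(torch.fft.ifftshift(x, dim=(-2, -1)), norm='ortho'),
        dim=(-2, -1))

def ifft_2d(x):
    # inverse of fft_2d
    return torch.fft.fftshift(
        torch.fft.ifft2(torch.fft.ifftshift(x, dim=(-2, -1)), norm='ortho'),
        dim=(-2, -1))

def root_sum_squares(x):
    # coil combine: sqrt(sum_c |x_c|^2) over the leading coil dim
    return torch.sqrt(torch.sum(torch.abs(x) ** 2, dim=0))

def crop_center(img, cropx, cropy):
    # crop the last two dims to (cropy, cropx) about the center
    y, x = img.shape[-2:]
    sy, sx = (y - cropy) // 2, (x - cropx) // 2
    return img[..., sy:sy + cropy, sx:sx + cropx]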
Example #2
def run_expmt():
    for file_id in file_id_list:

        f, ksp_orig = load_h5(file_id)
        ksp_orig = torch.from_numpy(ksp_orig)

        mask = get_mask(ksp_orig)

        net, net_input, ksp_orig_, _ = init_convdecoder(ksp_orig, mask)

        ksp_masked = SCALE_FAC * ksp_orig_ * mask
        img_masked = ifft_2d(ksp_masked)

        net, mse_wrt_ksp, mse_wrt_img = fit(ksp_masked=ksp_masked,
                                            img_masked=img_masked,
                                            net=net,
                                            net_input=net_input,
                                            mask2d=mask,
                                            num_iter=NUM_ITER)

        img_out, _ = net(net_input.type(dtype))
        img_out = reshape_adj_channels_to_complex_vals(img_out[0])
        ksp_est = fft_2d(img_out)
        ksp_dc = torch.where(mask, ksp_masked, ksp_est)

        img_est = crop_center(
            root_sum_squares(ifft_2d(ksp_est)).detach(), dim, dim)
        img_dc = crop_center(
            root_sum_squares(ifft_2d(ksp_dc)).detach(), dim, dim)
        img_gt = crop_center(root_sum_squares(ifft_2d(ksp_orig)), dim, dim)

        np.save('{}{}_est.npy'.format(path_out, file_id), img_est)
        np.save('{}{}_dc.npy'.format(path_out, file_id), img_dc)
        np.save('{}{}_gt.npy'.format(path_out, file_id), img_gt)
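The channel-layout docstring in Example #5 suggests that reshape_adj_channels_to_complex_vals and its inverse simply split a real (2*nc, x, y) tensor into real and imaginary halves. A sketch under that assumption (real channels first, then imaginary):

import torch

def reshape_complex_vals_to_adj_channels(x):
    # complex (nc, x, y) -> real (2*nc, x, y), laid out as [re | im]
    return torch.cat((x.real, x.imag), dim=0)

def reshape_adj_channels_to_complex_vals(x):
    # real (2*nc, x, y) -> complex (nc, x, y)
    nc = x.shape[0] // 2
    return torch.complex(x[:nc], x[nc:])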
Example #3
def run_demo():

    ksp_orig = load_h5_fastmri(file_id=None, demo=True)

    mask = get_mask(ksp_orig)

    net, net_input, ksp_orig_ = init_convdecoder(ksp_orig)

    ksp_masked = 0.1 * ksp_orig_ * mask

    net = fit(ksp_masked, net, net_input, mask)

    img_out = net(net_input.type(dtype))
    img_out = reshape_adj_channels_to_complex_vals(img_out[0])
    ksp_est = fft_2d(img_out)
    ksp_dc = torch.where(mask, ksp_masked, ksp_est)

    img_dc = crop_center(root_sum_squares(ifft_2d(ksp_dc)).detach(), dim, dim)
    img_gt = crop_center(root_sum_squares(ifft_2d(ksp_orig)), dim, dim)
    img_zf = crop_center(root_sum_squares(ifft_2d(ksp_masked)), dim, dim)

    np.save('data/out.npy', img_dc)
    np.save('data/gt.npy', img_gt)
    np.save('data/zf.npy', img_zf)
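A quick way to eyeball the demo results is to load the three saved arrays side by side; matplotlib is an assumption here, not part of the original demo.

import numpy as np
import matplotlib.pyplot as plt

imgs = [np.load('data/zf.npy'), np.load('data/out.npy'), np.load('data/gt.npy')]
titles = ['zero-filled', 'ConvDecoder + DC', 'ground truth']

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, img, title in zip(axes, imgs, titles):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()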
Example #4
def get_scale_factor(net, net_input, ksp_orig):
    ''' return the scaling factor, i.e. the ratio of magnitudes between the
        random image output by net(net_input) and the original image '''

    # generate random img
    net_output = net(net_input.type(dtype))
    net_output = net_output[0] if isinstance(net_output, tuple) else net_output
    out = torch.from_numpy(net_output.data.cpu().numpy()[0])
    out = reshape_adj_channels_to_complex_vals(out)
    out_img = root_sum_squares(out)

    # get img of input sample
    orig = ifft_2d(ksp_orig)
    orig_img = root_sum_squares(orig)

    return torch.linalg.norm(out_img) / torch.linalg.norm(orig_img)
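Since the ratio returned above is ||net output|| / ||original image||, the measured k-space is presumably multiplied by this factor so its magnitude matches what the untrained network produces. A hypothetical call site, likely inside init_convdecoder (an assumption, not shown in the source):

# hypothetical: rescale k-space to the magnitude range of the net output
scale_factor = get_scale_factor(net, net_input, ksp_orig)
ksp_orig_ = scale_factor * ksp_orig  # now ||rss(ifft(ksp_orig_))|| ~= ||rss(net(net_input))||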
Example #5
def forwardm(img, mask, mask2=None):
    ''' convert img --> ksp (must be complex for fft), apply mask
        convert back to img. input dim [2*nc,x,y], output dim [1,2*nc,x,y] 
        
        if adj (real-valued) channels:
            we have 2*nc, [re(e1) | re(e2) | im(e1) | im(e2)] 
        elif complex channels:
            we have nc, [re+im(e1) | re+im(e2)] '''

    img = reshape_adj_channels_to_complex_vals(img[0])
    ksp = fft_2d(img).cuda()

    if mask2 is None:
        ksp_masked_ = ksp * mask
    else:  # apply dual masks, i.e. mask to e1, e2 separately
        assert ksp.shape == (16, 512, 160)
        ksp_m_1 = ksp[:8] * mask
        ksp_m_2 = ksp[8:] * mask2
        ksp_masked_ = torch.cat((ksp_m_1, ksp_m_2), 0)

    img_masked_ = ifft_2d(ksp_masked_)

    return reshape_complex_vals_to_adj_channels(img_masked_)[None, :]
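A small sanity check of forwardm's shape contract; the 16-coil, 512x160 geometry is taken from the assert above, the random tensors are placeholders, and a CUDA device is required since forwardm moves k-space to the GPU.

import torch

nc = 16                                          # two echoes x 8 coils, per the assert
img = torch.randn(1, 2 * nc, 512, 160).cuda()    # stand-in for a net output [1, 2*nc, x, y]
mask = (torch.rand(512, 160) > 0.7).cuda()       # placeholder undersampling mask

out = forwardm(img, mask)
assert out.shape == (1, 2 * nc, 512, 160)        # same layout in, same layout out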
Example #6
def run_expmt(args):

    for file_id in args.file_id_list:

        ksp_orig = load_qdess(
            file_id, idx_kx=None)  # default central slice in kx (axial)

        for accel in args.accel_list:

            # manage paths for input/output
            path_base = '/bmrNAS/people/dvv/out_qdess/accel_{}x/'.format(accel)
            path_out = '{}{}/'.format(path_base, args.dir_out)
            args.path_gt = path_base + 'gt/'
            if os.path.exists('{}MTR_{}_e1.npy'.format(path_out, file_id)):
                continue
            if not os.path.exists(path_out):
                os.makedirs(path_out)
            if not os.path.exists(args.path_gt):
                os.makedirs(args.path_gt)

            # initialize network
            net1, net_input1, ksp_orig_ = init_convdecoder(
                ksp_orig, fix_random_seed=False)
            net2, net_input2, _ = init_convdecoder(ksp_orig,
                                                   fix_random_seed=False)

            # apply mask after rescaling k-space. want complex tensors dim (nc, ky, kz)
            ksp_masked, mask = apply_mask(ksp_orig_,
                                          accel,
                                          calib=args.calib,
                                          expmt=True)

            # fit network, get net output - default 10k iterations, lam_tv=1e-8
            net1, net2 = fit(ksp_masked=ksp_masked,
                             net1=net1,
                             net_input1=net_input1,
                             net2=net2,
                             net_input2=net_input2,
                             mask=mask)
            im_out1 = net1(
                net_input1.type(dtype))  # real tensor dim (2*nc, kx, ky)
            im_out2 = net2(
                net_input2.type(dtype))  # real tensor dim (2*nc, kx, ky)
            im_out = torch.mean(torch.stack([im_out1, im_out2]), dim=0)
            im_out = reshape_adj_channels_to_complex_vals(
                im_out[0])  # complex tensor dim (nc, kx, ky)

            # perform dc step
            ksp_est = fft_2d(im_out)
            ksp_dc = torch.where(mask, ksp_masked, ksp_est)
            np.save('{}MTR_{}_ksp_dc.npy'.format(path_out, file_id),
                    ksp_dc.detach().numpy())

            # create data-consistent, ground-truth images from k-space
            im_1_dc = root_sum_squares(ifft_2d(ksp_dc[:8])).detach()  # echo 1 coils
            im_2_dc = root_sum_squares(ifft_2d(ksp_dc[8:])).detach()  # echo 2 coils
            np.save('{}MTR_{}_e1.npy'.format(path_out, file_id), im_1_dc)
            np.save('{}MTR_{}_e2.npy'.format(path_out, file_id), im_2_dc)

            # save ground truth (with proper array scaling) if it does not already exist
            if not os.path.exists('{}MTR_{}_e1_gt.npy'.format(
                    args.path_gt, file_id)):
                im_1_gt = root_sum_squares(ifft_2d(ksp_orig[:8]))
                im_2_gt = root_sum_squares(ifft_2d(ksp_orig[8:]))
                np.save('{}MTR_{}_e1_gt.npy'.format(args.path_gt, file_id),
                        im_1_gt)
                np.save('{}MTR_{}_e2_gt.npy'.format(args.path_gt, file_id),
                        im_2_gt)

            print('recon {}'.format(file_id))

    return
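apply_mask is not shown; given the "32x32 forced to 1's" note in Example #7, one plausible sketch builds a Poisson-disc pattern with a fully sampled calibration region via sigpy (a third-party dependency assumed here; the real apply_mask may differ).

import numpy as np
import torch
from sigpy.mri import poisson  # assumption: sigpy is not imported anywhere in the original

def apply_mask_sketch(ksp, accel, calib=(32, 32), seed=0):
    # ksp: complex tensor (nc, ky, kz); returns masked k-space and the bool mask
    ny, nz = ksp.shape[-2:]
    mask = poisson((ny, nz), accel, calib=calib, seed=seed)
    mask = torch.from_numpy(np.abs(mask).astype(bool))
    return ksp * mask, mask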
Example #7
def run_expmt():

    path_in = '/bmrNAS/people/arjun/data/qdess_knee_2020/files_recon_calib-16/'
    #files = [f for f in listdir(path_in) if isfile(join(path_in, f))]
    #files.sort()
    #NUM_SAMPS = 10 # number of samples to recon

    NUM_ITER = 10000
    ACCEL_LIST = [8]  # e.g. [4, 6, 8]

    for fn in test_set:  # alternatively: files[:NUM_SAMPS]

        # load data
        f = h5py.File(path_in + fn, 'r')
        try:
            ksp = torch.from_numpy(f['kspace'][()])
        except KeyError:
            print('No kspace in file {} w keys {}'.format(fn, f.keys()))
            f.close()
            continue
        f.close()

        # NOTE: if change to echo2, must manually change path nomenclature
        ksp_vol = ksp[:,:,:,0,:].permute(3,0,1,2) # get echo1, reshape to be (nc, kx, ky, kz)

        # get central slice in kx, i.e. axial plane b/c we undersample in (ky, kz)
        idx_kx = ksp_vol.shape[1] // 2
        ksp_orig = ksp_vol[:, idx_kx, :, :]

        for ACCEL in ACCEL_LIST:

            path_out = '/bmrNAS/people/dvv/out_qdess/accel_{}x/echo1/'.format(ACCEL)
            
            # original masks created w central region 32x32 forced to 1's
            mask = torch.from_numpy(np.load('ipynb/masks/mask_poisson_disc_{}x.npy'.format(ACCEL)))

            # initialize network
            net, net_input, ksp_orig_, _ = init_convdecoder(ksp_orig, mask)

            # apply mask after rescaling k-space. want complex tensors dim (nc, ky, kz)
            ksp_masked = ksp_orig_ * mask
            img_masked = ifft_2d(ksp_masked)

            # fit network, get net output
            net, mse_wrt_ksp, mse_wrt_img = fit(
                ksp_masked=ksp_masked, img_masked=img_masked,
                net=net, net_input=net_input, mask2d=mask, num_iter=NUM_ITER)
            img_out, _ = net(net_input.type(dtype)) # real tensor dim (2*nc, kx, ky)
            img_out = reshape_adj_channels_to_complex_vals(img_out[0]) # complex tensor dim (nc, kx, ky)
            
            # perform dc step
            ksp_est = fft_2d(img_out)
            ksp_dc = torch.where(mask, ksp_masked, ksp_est)

            # create data-consistent, ground-truth images from k-space
            img_dc = root_sum_squares(ifft_2d(ksp_dc)).detach()
            img_gt = root_sum_squares(ifft_2d(ksp_orig))

            # save results
            samp = fn.split('.h5')[0]  # append '_echo2' here if reconstructing echo 2
            np.save('{}{}_dc.npy'.format(path_out, samp), img_dc)
            np.save('{}{}_gt.npy'.format(path_out, samp), img_gt)

            print('recon {}'.format(samp)) 

    return
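A possible evaluation step on the saved pairs, using scikit-image metrics (an assumption; the path matches the 8x run above and the sample id is a hypothetical placeholder).

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

path_out = '/bmrNAS/people/dvv/out_qdess/accel_8x/echo1/'  # as in the loop above
samp = 'MTR_xxx'                                           # hypothetical sample id

img_dc = np.load('{}{}_dc.npy'.format(path_out, samp))
img_gt = np.load('{}{}_gt.npy'.format(path_out, samp))

rng = img_gt.max() - img_gt.min()
print('PSNR {:.2f} dB'.format(peak_signal_noise_ratio(img_gt, img_dc, data_range=rng)))
print('SSIM {:.3f}'.format(structural_similarity(img_gt, img_dc, data_range=rng)))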
Example #8
def fit(ksp_masked,
        net,
        net_input,
        mask,
        mask2=None,
        num_iter=10000,
        lr=0.01,
        dtype=torch.cuda.FloatTensor,
        LAMBDA_TV=1e-8):
    ''' fit a network to masked k-space measurement
        args:
            ksp_masked: masked k-space of a single slice. torch variable [1,C,H,W]
            net: original network with randomly initiated weights
            net_input: randomly generated + scaled network input
            mask: 2D mask for undersampling the ksp
            mask2: 2D mask for echo2, if applying dual mask
            num_iter: number of iterations to optimize network
            lr: learning rate
        returns:
            net: the best network, whose output would be in image space
    '''

    # initialize variables
    net_input = net_input.type(dtype)
    best_net = copy.deepcopy(net)
    best_mse = 10000.0
    #mse_wrt_ksp, mse_wrt_img = np.zeros(num_iter), np.zeros(num_iter)

    params = list(net.parameters())
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=0)
    mse = torch.nn.MSELoss()

    img_masked = ifft_2d(ksp_masked)

    # convert complex [nc,x,y] --> real [2*nc,x,y] to match w net output
    ksp_masked = reshape_complex_vals_to_adj_channels(ksp_masked).cuda()
    img_masked = reshape_complex_vals_to_adj_channels(img_masked)[
        None, :].cuda()
    mask = mask.cuda()
    if mask2 is not None:
        mask2 = mask2.cuda()

    for i in range(num_iter):

        def closure():  # execute this for each iteration (gradient step)

            optimizer.zero_grad()

            out = net(net_input)  # out is in img space

            out_img_masked = forwardm(out, mask,
                                      mask2)  # img-->ksp, mask, convert to img

            loss_img = mse(out_img_masked, img_masked)

            loss_tv = (torch.sum(torch.abs(out_img_masked[:,:,:,:-1] - \
                                           out_img_masked[:,:,:,1:])) \
                     + torch.sum(torch.abs(out_img_masked[:,:,:-1,:] - \
                                           out_img_masked[:,:,1:,:])))
            loss_total = loss_img + LAMBDA_TV * loss_tv

            loss_total.backward(retain_graph=False)

            return loss_total

        loss = optimizer.step(closure)

        # at each iteration, check if the loss improved by at least 0.5%; if so, save a new best net
        loss_val = loss.item()
        if best_mse > 1.005 * loss_val:
            best_mse = loss_val
            best_net = copy.deepcopy(net)

    return best_net  #, mse_wrt_ksp, mse_wrt_img
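A minimal end-to-end sketch of calling fit, with stand-in shapes and a toy network in place of the ConvDecoder (everything below is a placeholder; it also relies on the fft/reshape helpers sketched after Examples #1 and #2, plus a CUDA device).

import torch

nc, ny, nz = 8, 64, 64
mask = torch.rand(ny, nz) > 0.75                             # bool undersampling mask
ksp_masked = torch.randn(nc, ny, nz, dtype=torch.complex64) * mask

net_input = torch.randn(1, 32, ny, nz)                       # fixed random input
net = torch.nn.Sequential(                                   # toy stand-in for ConvDecoder
    torch.nn.Conv2d(32, 64, 3, padding=1), torch.nn.ReLU(),
    torch.nn.Conv2d(64, 2 * nc, 3, padding=1)).cuda()

best_net = fit(ksp_masked, net, net_input, mask, num_iter=200)
img_out = reshape_adj_channels_to_complex_vals(
    best_net(net_input.cuda())[0]).detach()                  # complex (nc, ny, nz)
recon = root_sum_squares(img_out.cpu())                      # magnitude image (ny, nz)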