Esempio n. 1
0
 def __getitem__(self, idx):
     """Return the (input, label) tensor pair for sample ``idx``.

     The leading ``nch`` entries along axis 0 of the stored pack form the
     input (rescaled to [0, 255]); the last entry along axis 0 is the label
     (rescaled to [0, 1]). ``nch`` is read from module scope.
     """
     global nch
     sample = self.data[idx]
     inputs = util.ImageRescale(sample[:nch, :, :], [0, 255])
     label = util.ImageRescale(sample[-1, :, :], [0, 1])
     inputs_t, label_t = self.ToTensor(inputs, label)
     return inputs_t, label_t
Esempio n. 2
0
def get_train_data(num):
    """Build augmented training pairs from random patches of en-face slices.

    For every name in the module-level ``volume`` list, the input volume
    ``dataroot + name + '.nii.gz'`` and its label volume
    ``dataroot + 'sf(' + name + ').nii.gz'`` are loaded and cut to the
    axis-1 range ``slc_range[i]`` (the vessel layers). From every slice along
    axis 1, ``num`` random ``msk[0] x msk[1]`` patches are taken; each patch
    yields three pairs: as-is, left-right flipped and up-down flipped.

    Args:
        num: number of random patches sampled per slice.

    Returns:
        tuple of ``(im_x, im_y)`` pairs; inputs rescaled to [0, 255] and
        labels to [0, 1].
    """
    global msk
    # Collect into a list (O(1) append); the previous `tuple + tuple` in the
    # inner loop copied the whole accumulated result each time (O(n^2)).
    pairs = []
    for i, name in enumerate(volume):
        vol_x = util.nii_loader(dataroot + name + '.nii.gz')
        vol_y = util.nii_loader(dataroot + 'sf(' + name + ').nii.gz')
        # extract vessel layers
        start, stop = slc_range[i][0], slc_range[i][1]
        vol_x = vol_x[:, start:stop, :]
        vol_y = vol_y[:, start:stop, :]
        # size of en-face slices
        H, n_slc, W = vol_x.shape

        # iterate over the vessel layers
        for j in range(n_slc):
            x = util.ImageRescale(vol_x[:, j, :], [0, 255])
            y = util.ImageRescale(vol_y[:, j, :], [0, 1])

            # samples from single image (row seed drawn before column seed)
            for _ in range(num):
                r = random.randint(0, H - msk[0])
                c = random.randint(0, W - msk[1])
                im_x = x[r:r + msk[0], c:c + msk[1]]
                im_y = y[r:r + msk[0], c:c + msk[1]]
                pairs.extend((
                    (im_x, im_y),
                    (np.fliplr(im_x), np.fliplr(im_y)),
                    (np.flipud(im_x), np.flipud(im_y)),
                ))
    return tuple(pairs)
Esempio n. 3
0
def SegVAE(vol, model_dn, model_refine, model_vae):
    """Segment and synthesize every slice of ``vol`` from four border patches.

    ``vae_test_loader(vol)`` is read with ``batch_size=4`` so that one batch
    holds the four ``msk``-sized patches of one slice. Each batch goes
    through ``model_dn`` -> ``model_refine`` (second output: synthetic image)
    -> ``model_vae`` (first output: segmentation), and the four outputs are
    pasted back into slice ``step`` of the output volumes.

    Args:
        vol: volume of shape (h, slc, w).
        model_dn, model_refine, model_vae: the three networks of the pipeline.

    Returns:
        (vol_seg, vol_syn): float32 volumes shaped like ``vol``; each pasted
        patch is rescaled to [0, 255].
    """
    global msk
    h, slc, w = vol.shape
    mh, mw = msk[0], msk[1]
    # set batchsize to 4 so that 4 pieces are in the same stack
    test_loader = Data.DataLoader(dataset=vae_test_loader(vol),
                                  batch_size=4, shuffle=False)

    # define the output volumes
    vol_seg = np.zeros(vol.shape, dtype=np.float32)
    vol_syn = np.zeros(vol.shape, dtype=np.float32)

    # Destination (axis-0, axis-2) window of batch item k within a slice.
    # The 4th window previously used w - msk[0] for axis 2 instead of msk[1].
    windows = (
        (slice(None, mh), slice(None, mw)),          # axis0 start, axis2 start
        (slice(h - mh, None), slice(None, mw)),      # axis0 end,   axis2 start
        (slice(None, mh), slice(w - mw, None)),      # axis0 start, axis2 end
        (slice(h - mh, None), slice(w - mw, None)),  # axis0 end,   axis2 end
    )

    with torch.no_grad():  # inference only; no autograd graph needed
        for step, tensor_x in enumerate(test_loader):
            x = Variable(tensor_x).to(device)
            dn_x = model_dn(x)
            _, stack_syn = model_refine(dn_x)
            # segment the synthesized stack (was the undefined name `im_syn`)
            stack_seg, _ = model_vae(stack_syn)

            stack_seg = stack_seg.detach().cpu().numpy()
            stack_syn = stack_syn.detach().cpu().numpy()

            # One batch == one slice: write its 4 pieces into slice `step`
            # only (the old inner loop over all slices overwrote every slice
            # with every batch).
            for k, (rows, cols) in enumerate(windows):
                vol_seg[rows, step, cols] = util.ImageRescale(
                    stack_seg[k, 0, :, :], [0, 255])
                vol_syn[rows, step, cols] = util.ImageRescale(
                    stack_syn[k, 0, :, :], [0, 255])

    return vol_seg, vol_syn
Esempio n. 4
0
def MSUN_main(raw_Volume, FrameNum, Ndn, dn_display):
    """Denoise every frame of ``raw_Volume`` one slice at a time with ``Ndn``.

    Each frame is separated with ``PickFrame``, rows 150:662 of axis 1 are
    kept, and the first ``500 - nch + 1`` slices along axis 0 are passed
    through ``Ndn``. ``nch`` is read from module scope.

    Args:
        raw_Volume: raw multi-frame volume handed to ``PickFrame``.
        FrameNum: number of frames in ``raw_Volume``.
        Ndn: denoising network.
        dn_display: when truthy, show every 50th slice next to its prediction.

    Returns:
        float32 array of shape (FrameNum, 500 - nch + 1, 512, 500), each
        slice rescaled to [0, 255].
    """
    # Pre-define the output form
    depth = 500 - nch + 1
    opt = np.zeros([FrameNum, depth, 512, 500], dtype=np.float32)

    for idx in range(FrameNum):

        # [1] Frame separation and cropping -- from the argument; this used
        # to read a module-level `raw` and ignore `raw_Volume`.
        volume = PickFrame(raw_Volume, FrameNum, idx)
        volume = volume[:, 150:662, :]
        print('Frame {} separated.'.format(idx))

        for slc in range(depth):
            noi = volume[slc, :, :500]
            # NOTE(review): x is a 2-D tensor here while pred is indexed as
            # 4-D ([0, 0, ...]); confirm the input shape Ndn expects.
            x = torch.tensor(volume[slc, :, :]).type(torch.FloatTensor)
            with torch.no_grad():  # inference only
                x = Variable(x).to(device)
                pred = Ndn(x).detach().cpu().numpy()
            opt[idx, slc, :, :] = util.ImageRescale(pred[0, 0, :, :500],
                                                    [0, 255])

            if dn_display and slc % 50 == 0:
                plt.figure(figsize=(18, 6))
                plt.title('slc:{}'.format(slc))
                plt.axis('off')
                plt.imshow(np.concatenate([noi, opt[idx, slc, :, :]], axis=1),
                           cmap='gray')
                plt.show()

    return opt
Esempio n. 5
0
def merge(h, w, stack, loc):
    """Blend overlapping patches back into an ``(h, w)`` image by averaging.

    Patch ``stack[i]`` (rescaled to [0, 255]) is added with its top-left
    corner at ``loc[i]``; each pixel is then divided by the number of patches
    that covered it. Patch size is the module-level ``msk``.

    Args:
        h, w: size of the output image.
        stack: array of patches, indexed ``stack[i, :, :]``.
        loc: sequence of ``(row, col)`` top-left corners, one per patch.

    Returns:
        float32 ``(h, w)`` array. Pixels covered by no patch are 0 (they were
        0/0 = NaN, with a RuntimeWarning, before).
    """
    global msk
    mh, mw = msk[0], msk[1]
    opt = np.zeros([h, w], dtype=np.float32)
    counter = np.zeros([h, w], dtype=np.float32)

    for i, (x, y) in enumerate(loc):
        opt[x:x + mh, y:y + mw] += util.ImageRescale(stack[i, :, :], [0, 255])
        counter[x:x + mh, y:y + mw] += 1

    # only divide where at least one patch landed
    return np.divide(opt, counter, out=np.zeros_like(opt), where=counter > 0)
Esempio n. 6
0
def Mosaic(pattern, *args):
    """Tile equally-sized 2-D images into an ``nr x nc`` grid, row by row.

    Args:
        pattern: ``[nr, nc]``, number of tile rows and columns.
        *args: 2-D images, all shaped like ``args[0]``; each is rescaled to
            [0, 255] before placement.

    Returns:
        float32 array of shape ``(nr * r, nc * c)``; unused cells stay 0.

    Raises:
        ValueError: if more images are given than the grid has cells.
    """
    [nr, nc] = pattern
    [r, c] = args[0].shape
    if len(args) > nr * nc:
        # previously surfaced as an opaque numpy broadcast error
        raise ValueError('{} images do not fit a {}x{} mosaic.'.format(
            len(args), nr, nc))
    grand = np.zeros([nr * r, nc * c], dtype=np.float32)

    for cnt, im in enumerate(args):
        idx_r, idx_c = divmod(cnt, nc)  # row-major cell position
        grand[idx_r * r:(idx_r + 1) * r,
              idx_c * c:(idx_c + 1) * c] = util.ImageRescale(im, [0, 255])
    return grand
Esempio n. 7
0
def RandomCropper(im, num):
    """Cut ``msk``-sized patches from ``im``: the 4 corners plus ``num`` random.

    Args:
        im: 2-D image.
        num: number of additional random patch positions.

    Returns:
        (stack, lib): ``stack`` is a float32 array of shape
        ``(num + 4, msk[0], msk[1])`` with every patch rescaled to [0, 255];
        ``lib`` is a tuple of ``[row, col]`` top-left positions, corners first.

    Raises:
        ValueError: if the position count is not ``num + 4`` (i.e. when
        ``num`` is negative).
    """
    global msk
    ph, pw = msk[0], msk[1]
    h, w = im.shape
    corners = [[0, 0], [0, w - pw], [h - ph, 0], [h - ph, w - pw]]

    # random points (row drawn before column)
    randoms = []
    for _ in range(num):
        row = random.randint(0, h - ph)
        col = random.randint(0, w - pw)
        randoms.append([row, col])
    lib = tuple(corners + randoms)

    if len(lib) != num + 4:
        raise ValueError('Length not matched.')

    # form a stack for model testing
    stack = np.zeros([len(lib), ph, pw], dtype=np.float32)
    for n, (row, col) in enumerate(lib):
        patch = im[row:row + ph, col:col + pw]
        stack[n, :, :] = util.ImageRescale(patch, [0, 255])

    return stack, lib
Esempio n. 8
0
#%% Test
# Test loop: model1 is fed the sequence x_seq one step at a time, carrying its
# (h_, c_) state; model2 gets x_seq with dim 1 squeezed. Every 100 steps the
# input, target and both predictions are shown as a 2x2 grid.
# NOTE(review): from `y = Variable(y_seq[:, i, ...])` onwards this snippet
# looks spliced from a training loop. It uses names not defined here
# (`model`, `x`, `alpha`, `beta`, `criterion1`, `criterion2`, `sum_loss`,
# `optimizer`, `epoch`, `n_epoch`, `train_loader`, `t1`). It also calls
# `loss.backward()` inside `torch.no_grad()`, which would fail because the
# outputs carry no grad. Confirm the intended code before relying on it.
for step,(x_seq,y_seq) in enumerate(test_loader):
    with torch.no_grad():
        # these two lists are immediately replaced by state_init below
        h_ = []
        c_ = []
        B, n_seq, _, H, W = x_seq.size()
        h_, c_ = UNet_LSTM.state_init(enc_nch_1,B,H,W,device)
        
        # feed the sequence step by step; only the last y_pred_1 is kept
        for i in range(n_seq):
            x1 = Variable(x_seq[:,i,:,:,:]).to(device)    
            y_pred_1, h_, c_ = model1(x1, h_, c_)
        
        # squeeze(dim=1) only drops dim 1 when it has size 1 -- presumably n_seq == 1 here; confirm
        x2 = Variable(torch.squeeze(x_seq,dim=1)).to(device)
        y_pred_2 = model2(x2)
        
        # first sample, first 500 columns, each rescaled to [0, 255] for display
        im_x = util.ImageRescale(x2[0,-1,:,:500].detach().cpu().numpy(),[0,255])
        im_y = util.ImageRescale(y_seq[0,0,:,:500].numpy(),[0,255])
        im_1 = util.ImageRescale(y_pred_1[0,0,:,:500].detach().cpu().numpy(),[0,255])
        im_2 = util.ImageRescale(y_pred_2[0,0,:,:500].detach().cpu().numpy(),[0,255])
        
        # grid layout: predictions on top, input | target on the bottom
        bot = np.concatenate((im_x,im_y),axis=1)
        top = np.concatenate((im_1,im_2),axis=1)

        if step % 100 == 0:
            plt.figure(figsize=(16,16))
            plt.axis('off')
            plt.title('Slice: {}'.format(step),fontsize=20)
            plt.imshow(np.concatenate((top,bot),axis=0),cmap='gray')
            plt.show()
            # NOTE(review): training-loop code starts here (see note at top);
            # `i` is the last value left over from the sequence loop above.
            y = Variable(y_seq[:, i, :, :]).to(device)
            y_pred, h_, c_ = model(x, h_, c_)

            loss = alpha * criterion1(y_pred, y) + beta * criterion2(y_pred, y)
            sum_loss += loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if step % 200 == 0:
            print('[%d/%d][%d/%d]\tLoss: %.4f' %
                  (epoch, n_epoch, step, len(train_loader), loss.item()))

        # last step: show input | prediction | target side by side
        if step == len(train_loader) - 1:
            pred = util.ImageRescale(
                y_pred[0, 0, :, :500].detach().cpu().numpy(), [0, 255])
            im_y = util.ImageRescale(y[0, 0, :, :500].detach().cpu().numpy(),
                                     [0, 255])
            im_x = util.ImageRescale(x[0, 0, :, :500].detach().cpu().numpy(),
                                     [0, 255])

            plt.figure(figsize=(18, 6))
            plt.axis('off')
            plt.title('Epoch: {}'.format(epoch + 1), fontsize=15)
            plt.imshow(np.concatenate((im_x, pred, im_y), axis=1), cmap='gray')
            plt.show()

t2 = time.time()
print('Training finished. Time: {} min'.format((t2 - t1) / 60))
Esempio n. 10
0
file = 'fish_test.csv'

df = pd.read_csv(root + file)
row, col = df.shape
# NOTE(review): `opt` is never appended to here, so 'samples' always prints 0.
opt = ()

# For every fish (CSV row) and every session folder, crop the left-eye volume
# to the range stored in the CSV and save it as v30s{session}.nii.gz.
for i in range(row):
    # index 0 in col indicates the ID of the fish.
    # `np.str`/`np.int` were removed in NumPy 1.24; use the builtins.
    # `iloc` replaces chained `df.loc[i][k]`, whose integer key fell back to
    # (deprecated) positional lookup on the row Series.
    fish_id = df.iloc[i, 0]
    fish_root = dataroot + str(fish_id) + '\\'
    print('fish: {}, samples: {}'.format(fish_id, len(opt)))
    # get session number (last character of the session folder name)
    for session in os.listdir(fish_root):
        session_root = fish_root + session + '\\'
        j = int(session[-1])

        for item in os.listdir(session_root):
            if item.endswith('L.nii.gz'):
                vol_L = util.nii_loader(session_root + item)
            elif item.endswith('R.nii.gz'):
                vol_R = util.nii_loader(session_root + item)

        # Column j holds this session's ranges as a string. A non-string cell
        # (e.g. NaN for a missing entry) used to silently reuse the previous
        # session's range, or raise NameError on the first one; skip instead.
        entry = df.iloc[i, j]
        if not isinstance(entry, str):
            print('fish {} session {}: no range entry, skipped.'.format(
                fish_id, j))
            continue
        lib = entry
        rng_left, rng_right = string2idx(lib)

        vol = np.float32(vol_L[10:-10, rng_left[0]:rng_left[1], 210:690])
        vol = util.ImageRescale(vol, [0, 255])

        util.nii_saver(vol, dataroot + 'test_data\\',
                       'v30s{}.nii.gz'.format(j))
Esempio n. 11
0
        x_tensor, y_tensor = self.ToTensor(x,y)
        return x_tensor, y_tensor

test_loader = Data.DataLoader(dataset=dn_test_loader(dataroot),
                              batch_size=1, shuffle=False)

#%%
# Run each (x, y) test pair through model_x / model_y and collect both outputs,
# rescaled to [0, 255], as VAE training pairs.
# A list is grown with O(1) appends; `tuple + tuple` per step was O(n^2).
vae_train_data = []
with torch.no_grad():
    for step, (tensor_x, tensor_y) in enumerate(test_loader):
        x = Variable(tensor_x).to(device)
        y = Variable(tensor_y).to(device)

        dn_x = model_x(x)
        dn_y = model_y(y)

        im_x = util.ImageRescale(dn_x[0, 0, :, :].detach().cpu().numpy(),
                                 [0, 255])
        im_y = util.ImageRescale(dn_y[0, 0, :, :].detach().cpu().numpy(),
                                 [0, 255])

        vae_train_data.append((im_x, im_y))

        # preview every 500th pair, skipping the very first
        if step % 500 == 0 and step != 0:
            plt.figure(figsize=(12, 6))
            plt.axis('off')
            plt.imshow(np.concatenate((im_x, im_y), axis=1), cmap='gray')
            plt.show()

# pickle a tuple, as before, so existing readers of the file are unaffected
vae_train_data = tuple(vae_train_data)
with open('E:\\OCTA\\data\\VAE_train_data.pickle', 'wb') as handle:
    pickle.dump(vae_train_data, handle)
        
Esempio n. 12
0
# map_location lets checkpoints saved on another device (e.g. a GPU) be loaded
# onto `device`; without it torch.load restores tensors to their saved device.
Nsf = NetworkArch.MS_UNet(gpu, nch_1).to(device)
Nsf.load_state_dict(torch.load(modelroot + 'noi2sf_MSUNet.pt',
                               map_location=device))

Ndn = NetworkArch.MS_UNet(gpu, nch_2).to(device)
Ndn.load_state_dict(torch.load(modelroot + '1_12.pt', map_location=device))

#%% load the raw data

if __name__ == '__main__':

    root = 'E:\\human\\'
    # all Retina2_Fovea*.tif files in root, sorted by name
    volumelist = sorted(
        name for name in os.listdir(root)
        if name.startswith('Retina2_Fovea') and name.endswith('.tif'))
    if not volumelist:
        # fail with a clear message instead of an IndexError below
        raise FileNotFoundError(
            'No Retina2_Fovea*.tif volume found in {}'.format(root))

    volume = volumelist[0]
    # keep the name `raw`: code elsewhere may read it as a module-level global
    raw = util.ImageRescale(io.imread(root + volume), [0, 255])

    FrameNum = 5
    nch = 7

    t1 = time.time()
    V = PMFN_main(raw, FrameNum, nch, Nsf, Ndn, False, False)
    t2 = time.time()
    print('Denoised finish, time used: {} min'.format((t2 - t1) / 60))

    util.nii_saver(V, 'E:\\Retina2_Fovea\\101_1\\', 'MSUN_101.nii.gz')
Esempio n. 13
0
            optimizer.step()
        else:
            Seg_opt.zero_grad()
            Syn_opt.zero_grad()
            
            loss.backward()
            
            Seg_opt.step()
            Syn_opt.step()

        if step % 250 == 0:
            print('[%d/%d][%d/%d][L1:%.4f | L2:%.4f]'%(epoch,n_epoch,
                  step,len(train_loader),l1.item(),l2.item()))
        
        if step % 1000 == 0 and step != 0:
            seg = util.ImageRescale(y_seg[0,0,:,:].detach().cpu().numpy(),[0,255])
            syn = util.ImageRescale(y_syn[0,0,:,:].detach().cpu().numpy(),[0,255])
            im_x = util.ImageRescale(x[0,0,:,:].detach().cpu().numpy(),[0,255])
            im_y = util.ImageRescale(y[0,0,:,:].detach().cpu().numpy(),[0,255])
            
            top = np.concatenate((im_x,im_y),axis=1)
            bot = np.concatenate((seg,syn),axis=1)
            
            plt.figure(figsize=(12,12))
            plt.axis('off')
            plt.title('Epoch: {}'.format(epoch),fontsize=15)
            plt.imshow(np.concatenate((top,bot),axis=0),cmap='gray')
            plt.show()
            
    epoch_loss.append(sum_loss)
    
Esempio n. 14
0
r_var = 3
r_proj = 7

# For every AR*.nii.gz volume: a 'mean' loc_Proc of the original, a 'var'
# loc_Proc, and a 'mean' loc_Proc of that variance; crop all of them with
# `cutter` against the variance projection, then save the results.
for file in os.listdir(dataroot):
    if not (file.startswith('AR') and file.endswith('.nii.gz')):
        continue

    vol = util.nii_loader(dataroot + file)
    h, slc, w = vol.shape
    orig_proj = loc_Proc(vol, 3, 3, slc - 3, 'mean')            # local_proj{orig}
    var = loc_Proc(vol, r_var, r_var, slc - r_var, 'var')       # var
    _, slc, _ = var.shape
    var_proj = loc_Proc(var, r_proj, r_proj, slc - r_proj, 'mean')  # local_proj{var}

    vol_vp, vol_v = cutter(var_proj, var)
    _, vol_op = cutter(var_proj, orig_proj)
    _, vol_orig = cutter(var_proj, vol)

    # enhance each axis-2 plane of the cropped original projection
    for i in range(w):
        vol_op[:, :, i] = ContrastEnhance(vol_op[:, :, i])

    # output names reuse everything from the first '_' of the input name on
    idx = file.find('_')
    suffix = file[idx:]
    util.nii_saver(vol_orig, saveroot, 'orig{}'.format(suffix))
    rescaled_outputs = (
        (vol_op, 'proj(orig){}'),
        (vol_v, 'var{}'),
        (vol_vp, 'proj(var){}'),
    )
    for data, template in rescaled_outputs:
        util.nii_saver(util.ImageRescale(data, [0, 255]), saveroot,
                       template.format(suffix))
Esempio n. 15
0
 def __getitem__(self, idx):
     """Return ``(x, y1, y2)`` tensors for sample ``idx``.

     ``x`` is handed to ``ToTensor`` unchanged; both targets are rescaled to
     [0, 1] first (y1 before y2).
     """
     x, y1, y2 = self.pair[idx]
     target1 = util.ImageRescale(y1, [0, 1])
     target2 = util.ImageRescale(y2, [0, 1])
     x_t, y1_t, y2_t = self.ToTensor(x, target1, target2)
     return x_t, y1_t, y2_t
Esempio n. 16
0
def PMFN_main(raw_Volume, FrameNum, nch, Nsf, Ndn, sf_display, dn_display):
    """Two-stage denoising of every frame of ``raw_Volume``.

    For each of the ``FrameNum`` frames:
      1. separate it with ``PickFrame`` and keep rows 150:662 of axis 1;
      2. run ``BscanRegist(volume, nch)`` (B-scan registration);
      3. run ``Nsf`` on each registered sample and build a 3-channel input
         ``[channel 1 of the sample, Sobel(prediction), prediction]``;
      4. run ``Ndn`` on that 3-channel input and store the result.

    Args:
        raw_Volume: raw multi-frame volume handed to ``PickFrame``.
        FrameNum: number of frames in ``raw_Volume``.
        nch: parameter of ``BscanRegist``; also sets ``depth = 500 - nch + 1``.
        Nsf, Ndn: first- and second-stage networks.
        sf_display, dn_display: when truthy, show every 20th slice of
            stage 3 / stage 4.

    Returns:
        float32 array of shape (FrameNum, 500 - nch + 1, 512, 500), each
        slice rescaled to [0, 255].
    """
    # Pre-define the output form
    depth = 500 - nch + 1
    opt = np.zeros([FrameNum, depth, 512, 500], dtype=np.float32)

    for idx in range(FrameNum):

        # [1] Frame separation and cropping -- from the argument; this used
        # to read a module-level `raw` and ignore `raw_Volume`.
        volume = PickFrame(raw_Volume, FrameNum, idx)
        volume = volume[:, 150:662, :]
        print('Frame {} separated.'.format(idx))

        # [2] Bscan Registration
        print('Registering Bscans...')
        test_x = BscanRegist(volume, nch)

        # [3] noise to self-fusion-y
        print('Predicting self-fusion-y...')
        denoise_x = []
        sf_loader = Data.DataLoader(dataset=MyDataset(test_x),
                                    batch_size=1,
                                    shuffle=False)

        for step, x in enumerate(sf_loader):
            with torch.no_grad():
                x = Variable(x).to(device)
                pred = Nsf(x).detach().cpu().numpy()

                x_opt = np.zeros([3, 512, 512], dtype=np.float32)

                bscans = x.detach().cpu().numpy()
                x_opt[0, :, :] = bscans[0, 1, :, :]

                pred = util.ImageRescale(pred[0, 0, :, :500], [0, 255])
                gradient = Sobel(pred, 3)

                # NOTE(review): channel order here is [bscan, gradient, pred],
                # while the snippet that builds (x_stack, y1, y2) training
                # pairs stores [bscan, pred, gradient]. Confirm which order
                # Ndn was trained on.
                x_opt[1, :, :500] = gradient
                x_opt[2, :, :500] = pred
                denoise_x.append(x_opt)

                if sf_display and step % 20 == 0:
                    plt.figure(figsize=(18, 6))
                    plt.title('slc:{}'.format(step))
                    plt.axis('off')
                    plt.imshow(np.concatenate(
                        [x_opt[0, :, :500], pred, gradient], axis=1),
                               cmap='gray')
                    plt.show()
        del sf_loader

        # [4] multi-modal input to denoised
        print('De-speckling...')
        dn_loader = Data.DataLoader(dataset=MyDataset(denoise_x),
                                    batch_size=1,
                                    shuffle=False)

        for step, x in enumerate(dn_loader):
            with torch.no_grad():

                x = Variable(x).to(device)
                noi = x[0, 0, :, :500].detach().cpu().numpy()
                pred = Ndn(x).detach().cpu().numpy()
                opt[idx, step, :, :] = util.ImageRescale(pred[0, 0, :, :500],
                                                         [0, 255])

                if dn_display and step % 20 == 0:
                    plt.figure(figsize=(18, 6))
                    plt.title('slc:{}'.format(step))
                    plt.axis('off')
                    plt.imshow(np.concatenate([noi, opt[idx, step, :, :]],
                                              axis=1),
                               cmap='gray')
                    plt.show()
        del dn_loader
        print('------------------------------------------------')

    return opt
Esempio n. 17
0
 def __getitem__(self, idx):
     """Return the ``(x, y)`` tensor pair for sample ``idx``; ``y`` is rescaled to [0, 1]."""
     sample_x, sample_y = self.pair[idx]
     target = util.ImageRescale(sample_y, [0, 1])
     x_tensor, y_tensor = self.ToTensor(sample_x, target)
     return x_tensor, y_tensor
Esempio n. 18
0
     test_loader = Data.DataLoader(dataset=test(dataroot+file),batch_size=1,
                                   shuffle=False)
     print('volume {} loaded.'.format(file))
 
 # Output volumes: one 480 x 480 image per test sample, stacked along axis 1
 # (filled below as vol_*[:, step, :]).
 vol_latent = np.zeros([480,len(test_loader),480],dtype=np.float32)
 vol_syn = np.zeros([480,len(test_loader),480],dtype=np.float32)
 
 # test
 # NOTE(review): no torch.no_grad() around this inference loop, so autograd
 # graphs are built; consider wrapping it if gradients are not needed.
 for step, tensor_x in enumerate(test_loader):
     x = Variable(tensor_x).to(device)
     # take the synthesized image as denoised x
     _,dn_x = model1(x)
     latent,syn = model2(dn_x)
     
     # first channel of the single sample (batch_size=1), rescaled to [0, 255]
     im_x = x[0,0,:,:].detach().cpu().numpy()
     im_x = util.ImageRescale(im_x,[0,255])
     
     # latent image (negated before rescaling)
     im_latent = latent[0,0,:,:].detach().cpu().numpy()
     im_latent = util.ImageRescale(-im_latent,[0,255])
     
     # synthetic image
     im_syn = syn[0,0,:,:].detach().cpu().numpy()
     im_syn = util.ImageRescale(im_syn,[0,255])
     
     vol_latent[:,step,:] = im_latent
     vol_syn[:,step,:] = im_syn
     
     # every 5th sample: input | latent | synthetic side by side
     # (the matching plt.show() is not visible in this snippet)
     if step % 5 == 0:
         plt.figure(figsize=(15,5))
         plt.imshow(np.concatenate((im_x,im_latent,im_syn),axis=1),cmap='gray')
Esempio n. 19
0
        Nsf_test_loader = Data.DataLoader(dataset=Nsf_test_dataset(dataroot +
                                                                   hn_list[i]))
        print('dataloader {} created'.format(hn_list[i]))

        for step, x in enumerate(Nsf_test_loader):
            with torch.no_grad():
                x = Variable(x).to(device)
                pred = Nsf(x).detach().cpu().numpy()

                # high noise bscan -- Nsf(x) -- sobel
                x_stack = np.zeros([3, 512, 512], dtype=np.float32)

                bscans = x.detach().cpu().numpy()
                x_stack[0, :, :] = bscans[0, 1, :, :]
                x_stack[1, :, :500] = util.ImageRescale(
                    pred[0, 0, :, :500], [0, 255])
                x_stack[2, :, :500] = Sobel(x_stack[1, :, :500], 3)

                y1 = np.zeros([512, 512], dtype=np.float32)
                y1[:, :500] = ln[step, :, :500]

                y2 = np.zeros([512, 512], dtype=np.float32)
                y2[:, :500] = sf_ln[step, :, :500]

                pair_data = pair_data + ((x_stack, y1, y2), )

            if step == 200:
                plt.figure(figsize=(10, 5))
                plt.axis('off')
                plt.imshow(np.concatenate((x_stack[0, :, :], x_stack[1, :, :]),
                                          axis=1),
Esempio n. 20
0
        vlist.append(file)

# Settings. The former `global nFrame, fixFrame, sf_r` statement was removed:
# `global` has no effect at module level, so plain assignments suffice.
nFrame = 5
fixFrame = 0  # frame saved as HN and passed to FrameAver
sf_r = 3

saveroot = 'E:\\real-time-PMFN\\data\\'

t1 = time.time()

for i, name in enumerate(vlist):
    print('Creating low noise volume....')
    # rescaling -> crop-out the massive back ground -> reshape
    HN = util.ImageRescale(io.imread(dataroot + name), [0, 255])
    HN = Re_Arrange(HN[:, :512, :])
    # Frame-average
    LN = FrameAver(HN, fixFrame)

    # output names use name[8:-4] (presumably drops a fixed prefix and '.tif')
    util.nii_saver(HN[fixFrame, :, :, :], saveroot,
                   'HN_{}.nii.gz'.format(name[8:-4]))
    util.nii_saver(LN, saveroot, 'LN_{}.nii.gz'.format(name[8:-4]))

    # display a sample
    plt.figure(figsize=(10, 5))
    plt.axis('off')
    plt.title('Volume #{}:'.format(i + 1) + name, fontsize=15)
    plt.imshow(np.concatenate((HN[0, 100, :, :], LN[100, :, :]), axis=1),
               cmap='gray')
    plt.show()
Esempio n. 21
0
dataroot = '/home/dewei/Desktop/octa/data/'
temp = '/home/dewei/Desktop/octa/temp/'
volume = ["vol_octa"]
radius = 3

# Self-fuse each volume with a sliding window of 2 * radius + 1 slices along
# axis 1 (after swapping axes 0 and 1), once on the raw window and once on the
# window registered by tool.greedy. `temp` is passed through to both tools.
for i, name in enumerate(volume):
    vol = np.transpose(util.nii_loader(dataroot + name + '.nii.gz'), [1, 0, 2])
    h, slc, w = vol.shape
    n_slc = slc - 2 * radius

    # define the output volumes (the first/last `radius` slices are dropped)
    vol_reg = np.zeros([h, n_slc, w], dtype=np.float32)
    vol_non_reg = np.zeros([h, n_slc, w], dtype=np.float32)

    for out_j, j in enumerate(range(radius, slc - radius)):
        window = vol[:, j - radius:j + radius + 1, :]
        window_rg = tool.greedy(window, temp)
        fused_raw = tool.sf(window, temp)
        vol_non_reg[:, out_j, :] = util.ImageRescale(fused_raw, [0, 255])
        fused_reg = tool.sf(window_rg, temp)
        vol_reg[:, out_j, :] = util.ImageRescale(fused_reg, [0, 255])

    # save the volume
    util.nii_saver(vol_non_reg, dataroot, 'sf_' + name + '.nii.gz')
    util.nii_saver(vol_reg, dataroot, 'sf_reg_' + name + '.nii.gz')
    print('volume {} self-fused.'.format(name))

print('Execution finished.')
Esempio n. 22
0
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss:%.4f' %
                  (epoch, nepoch, step, len(sf_loader), loss))

        if step % 500 == 0:
            with torch.no_grad():

                im_grand = np.zeros([512, 1000], dtype=np.float32)

                denoise = pred.detach().cpu().numpy()
                avg = train_y.detach().cpu().numpy()

                im_grand[:, :500] = util.ImageRescale(denoise[0, 0, :, :500],
                                                      [0, 255])
                im_grand[:,
                         500:1000] = util.ImageRescale(avg[0, 0, :, :500],
                                                       [0, 255])

                plt.figure(figsize=(12, 6))
                plt.axis('off')
                plt.imshow(im_grand, cmap='gray')
                plt.show()

    scheduler.step()

#%%
# Persist the trained Nsf weights under modelroot.
name = 'HN2SF_nch=3.pt'
save_path = modelroot + name
torch.save(Nsf.state_dict(), save_path)