def objectiveFunction(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph, kern,
                      dirWeight=0, dirs=None, dirInfo=None, nmins=0, wavelet='db4', mode="per", a=10.):
    '''
    Compressed-sensing cost for the wavelet coefficients ``x``.

    cost = data-consistency term
           + lam1 * TV term
           + lam2 * sum(log(cosh(a*x)) / a)   (smooth approximation of |x|)

    The TV and wavelet (XFM) terms are skipped when their weight is <= 1e-6.

    Parameters
    ----------
    x : ndarray
        Flat wavelet coefficients supplied by the optimizer. Temporarily
        reshaped in place to ``N``.
    N, N_im :
        Wavelet-coefficient shape and image shape. When ``len(N) > 2`` the
        first axis is treated as independent slices.
    dims, dimOpt, dimLenOpt, wavelet, mode :
        Passed through to ``tf.iwt``.
    lam1, lam2 : float
        TV and XFM weights.
    data : ndarray
        k-space data. Temporarily reshaped in place to ``N_im``.
    k, ph :
        Passed to ``objectiveFunctionDataCons``.
    strtag, kern, dirWeight, dirs, nmins, dirInfo :
        Passed to ``objectiveFunctionTV``. ``dirInfo=None`` means
        ``[None, None, None, None]`` (``dirInfo[0]`` is M).
    a : float
        Sharpness of the log-cosh approximation.

    Returns
    -------
    Absolute value of the total cost.

    ``x`` and ``data`` are flattened again before returning, even if an
    exception is raised, because the optimizer works on flat arrays.
    '''
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    try:
        data.shape = N_im
        x.shape = N
        if len(N) > 2:
            # Rebuild the image slice by slice from its wavelet coefficients.
            x0 = np.zeros(N_im)
            for i in range(N[0]):
                x0[i,:,:] = tf.iwt(x[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
        else:
            x0 = tf.iwt(x,wavelet,mode,dims,dimOpt,dimLenOpt)

        obj = np.sum(objectiveFunctionDataCons(x0,N_im,ph,data,k))

        if lam1 > 1e-6:
            tv = np.sum(objectiveFunctionTV(x0,N_im,strtag,kern,dirWeight,dirs,nmins,dirInfo=dirInfo,a=a))

        if lam2 > 1e-6:
            # log(cosh(z)) == |z| + log1p(exp(-2|z|)) - log(2).
            # Same value, but np.cosh overflows to inf for |z| > ~710 while
            # this form never does. Dividing the float array by ``a`` also
            # avoids the Python 2 integer division of ``1/a`` for an int ``a``.
            az = np.abs(a*x)
            xfm = np.sum((az + np.log1p(np.exp(-2.0*az)) - np.log(2.0))/a)
    finally:
        # Restore the flat shapes the optimizer expects.
        x.shape = (x.size,)
        data.shape = (data.size,)
    return abs(obj + lam1*tv + lam2*xfm)
def derivativeFunction(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph,
                       kern, dirWeight=0.1, dirs=None, dirInfo=None, nmins=0, wavelet="db4", mode="per", a=10.):
    '''
    Gradient of the compressed-sensing cost with respect to the wavelet
    coefficients ``x``. This is the gradient (``jac``) handed to the scipy
    optimizer.

    The data-consistency and TV gradients are computed in image space and
    mapped back to the wavelet domain with ``tf.wt``. The wavelet (XFM)
    gradient comes from ``grads.gXFM``. The TV and XFM terms are skipped
    when their weight is <= 1e-6. ``dirInfo=None`` means
    ``[None, None, None, None]``.

    NOTE(review): the default ``dirWeight`` here (0.1) differs from the
    default of ``objectiveFunction`` (0). Callers should pass it explicitly.

    ``x`` is temporarily reshaped in place to ``N``. It is flattened again
    before returning, even if an exception is raised.

    Returns the flattened gradient.
    '''
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    gTV = 0
    gXFM = 0
    try:
        x.shape = N
        if len(N) > 2:
            x0 = np.zeros(N_im)
            for i in range(N[0]):
                x0[i,:,:] = tf.iwt(x[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
        else:
            x0 = tf.iwt(x,wavelet,mode,dims,dimOpt,dimLenOpt)

        gDataCons = tf.wt(grads.gDataCons(x0,N_im,ph,data,k),wavelet,mode,dims,dimOpt,dimLenOpt)[0]
        if lam1 > 1e-6:
            # TV gradient, mapped back into the wavelet domain.
            gTV = tf.wt(grads.gTV(x0,N_im,strtag,kern,dirWeight,dirs,nmins,dirInfo=dirInfo,a=a),wavelet,mode,dims,dimOpt,dimLenOpt)[0]
        if lam2 > 1e-6:
            gXFM = grads.gXFM(x,a=a)
    finally:
        # Restore the flat shape the optimizer expects.
        x.shape = (x.size,)

    return (gDataCons + lam1*gTV + lam2*gXFM).flatten()
Exemplo n.º 3
0
def derivativeFunction(x,
                       N,
                       N_im,
                       dims,
                       dimOpt,
                       dimLenOpt,
                       lam1,
                       lam2,
                       data,
                       k,
                       strtag,
                       ph,
                       kern,
                       dirWeight=0.1,
                       dirs=None,
                       dirInfo=None,
                       nmins=0,
                       wavelet="db4",
                       mode="per",
                       a=10.):
    '''
    Gradient of the compressed-sensing cost with respect to the wavelet
    coefficients ``x``. This is the gradient (``jac``) handed to the scipy
    optimizer.

    The data-consistency and TV gradients are computed in image space and
    mapped back to the wavelet domain with ``tf.wt``. The wavelet (XFM)
    gradient comes from ``grads.gXFM``. The TV and XFM terms are skipped
    when their weight is <= 1e-6. ``dirInfo=None`` means
    ``[None, None, None, None]``.

    ``x`` is temporarily reshaped in place to ``N``. It is flattened again
    before returning, even if an exception is raised.

    Returns the flattened gradient.
    '''
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    gTV = 0
    gXFM = 0
    try:
        x.shape = N
        if len(N) > 2:
            x0 = np.zeros(N_im)
            for i in range(N[0]):
                x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt,
                                     dimLenOpt)
        else:
            x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)

        gDataCons = tf.wt(grads.gDataCons(x0, N_im, ph, data, k), wavelet,
                          mode, dims, dimOpt, dimLenOpt)[0]
        if lam1 > 1e-6:
            # TV gradient, mapped back into the wavelet domain.
            gTV = tf.wt(
                grads.gTV(x0,
                          N_im,
                          strtag,
                          kern,
                          dirWeight,
                          dirs,
                          nmins,
                          dirInfo=dirInfo,
                          a=a), wavelet, mode, dims, dimOpt, dimLenOpt)[0]
        if lam2 > 1e-6:
            gXFM = grads.gXFM(x, a=a)
    finally:
        # Restore the flat shape the optimizer expects.
        x.shape = (x.size, )

    return (gDataCons + lam1 * gTV + lam2 * gXFM).flatten()
Exemplo n.º 4
0
def f(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph,
                      kern, dirWeight=0, dirs=None, dirInfo=None, nmins=0, wavelet='db4', mode="per", level=3, a=10., mask=None, kmask=None):
    '''
    Compressed-sensing cost for the wavelet coefficients ``x``.

    cost = data-consistency term (``kmask`` passed through)
           + lam1 * TV term
           + lam2 * sum(log(cosh(a*x)) / a)

    The TV and XFM terms are skipped when their weight is <= 1e-6.
    ``dirInfo=None`` means ``[None, None, None, None]`` (``dirInfo[0]`` is M).

    ``level`` and ``mask`` are accepted for signature compatibility with
    ``df`` but are not used here.
    NOTE(review): ``df`` multiplies the reconstructed image by ``mask``
    while this function does not, so cost and gradient disagree whenever a
    non-trivial mask is passed. Confirm which behaviour is intended.

    ``x`` and ``data`` are temporarily reshaped in place (to ``N`` and
    ``N_im``). They are flattened again before returning, even on error.
    A ``RuntimeWarning`` is issued when ``x`` contains inf or nan.

    Returns the total cost.
    '''
    import warnings

    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    try:
        data.shape = N_im
        x.shape = N
        if len(N) > 2:
            # Rebuild the image slice by slice from its wavelet coefficients.
            x0 = np.zeros(N_im)
            for i in range(N[0]):
                x0[i,:,:] = tf.iwt(x[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
        else:
            x0 = tf.iwt(x,wavelet,mode,dims,dimOpt,dimLenOpt)

        obj = np.sum(objectiveFunctionDataCons(x0,N_im,ph,data,k,kmask=kmask))

        if lam1 > 1e-6:
            tv = np.sum(objectiveFunctionTV(x0,N_im,strtag,kern,dirWeight,dirs,nmins,dirInfo=dirInfo,a=a))

        if lam2 > 1e-6:
            # Overflow-free log-cosh:
            #     log(cosh(z)) == |z| + log1p(exp(-2|z|)) - log(2).
            # np.cosh overflows to inf for |z| > ~710; this form stays finite
            # and exact, so no after-the-fact patching of inf entries is needed.
            az = np.abs(a*x)
            xfm = np.sum((az + np.log1p(np.exp(-2.0*az)) - np.log(2.0))/a)
    finally:
        # Restore the flat shapes the optimizer expects. ascontiguousarray
        # returns the same object for contiguous input, so in that (usual)
        # case the caller's arrays are reshaped.
        x = np.ascontiguousarray(x)
        x.shape = (x.size,)
        data = np.ascontiguousarray(data)
        data.shape = (data.size,)

    if not np.all(np.isfinite(x)):
        # Report rather than dropping into a debugger, which would hang
        # non-interactive optimisation runs.
        warnings.warn('f: optimization variable x contains inf or nan', RuntimeWarning)
    return obj + lam1*tv + lam2*xfm
Exemplo n.º 5
0
def df(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph,
                       kern, dirWeight=0.1, dirs=None, dirInfo=None, nmins=0, wavelet="db4", mode="per", level=3, a=10., mask=None, kmask=None):
    '''
    Gradient of ``f`` with respect to the wavelet coefficients ``x``.

    Steps:
    1. Rebuild the image slice by slice with ``tf.iwt``.
    2. Multiply it by ``mask`` when one is given (broadcast over the first
       axis).
    3. Compute the data-consistency and TV gradients in image space and map
       them back to the wavelet domain slice by slice with ``tf.wt``.
    4. Take the XFM gradient from ``grads.gXFM``.

    The TV and XFM terms are skipped when their weight is <= 1e-6.
    ``dirInfo=None`` means ``[None, None, None, None]``.

    ``x`` is temporarily reshaped in place to ``N``. It is flattened again
    before returning, even on error.

    Returns the flattened gradient.
    '''
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    try:
        x.shape = N
        if len(N) > 2:
            x0 = np.zeros(N_im)
            for i in range(N[0]):
                x0[i,:,:] = tf.iwt(x[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
        else:
            x0 = tf.iwt(x,wavelet,mode,dims,dimOpt,dimLenOpt)

        # mask is optional (default None). Subscripting None would raise a
        # TypeError, so only apply the mask when one is given.
        if mask is not None:
            x0 = x0*mask[np.newaxis,:,:]
        gdc = grads.gDataCons(x0,N_im,ph,data,k,kmask=kmask)
        if lam1 > 1e-6:
            gtv = grads.gTV(x0,N_im,strtag,kern,dirWeight,dirs,nmins,dirInfo=dirInfo,a=a)

        gDataCons = np.zeros(N)
        gTV = np.zeros(N)
        gXFM = np.zeros(N)

        # Map the image-space gradients back to the wavelet domain per slice.
        # NOTE(review): tf.wt receives mask=None when no mask is given.
        # Confirm it accepts that.
        for i in range(N[0]):
            gDataCons[i,:,:] = tf.wt(gdc[i,:,:],wavelet,mode,level,dims,dimOpt,dimLenOpt,mask)[0]
            if lam1 > 1e-6:
                gTV[i,:,:] = tf.wt(gtv[i,:,:],wavelet,mode,level,dims,dimOpt,dimLenOpt,mask)[0]
            if lam2 > 1e-6:
                gXFM[i,:,:] = grads.gXFM(x[i,:,:],a=a)
    finally:
        # Restore the flat shape the optimizer expects.
        x.shape = (x.size,)

    return (gDataCons + lam1*gTV + lam2*gXFM).flatten()
        w_result = opt.minimize(f, w_dc, args=args, method=method, jac=df, 
                                    options={'maxiter': ItnLim, 'lineSearchItnLim': lineSearchItnLim, 'gtol': 0.01, 'disp': 1, 'alpha_0': alpha_0, 'c': c, 'xtol': xtol[i], 'TVWeight': TV[i], 'XFMWeight': XFM[i], 'N': N})
        if np.any(np.isnan(w_result['x'])):
            print('Some nan''s found. Dropping TV and XFM values')
        elif w_result['status'] != 0:
            print('TV and XFM values too high -- no solution found. Dropping...')
        else:
            w_dc = w_result['x']
            stps.append(w_dc)
            tvStps.append(TV[i])
            
            
    w_res = w_dc.reshape(N)
    im_res = np.zeros(N_im)
    for i in xrange(N[0]):
        im_res[i,:,:] = tf.iwt(w_res[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
    ims.append(im_res)
    
# Post-process every accepted step stored in ``stps``. For each step:
# rebuild the image from its wavelet coefficients, then evaluate the TV,
# XFM and data-consistency gradient terms for it (one row per step).
# NOTE(review): reshape(N[-2:]) assumes one 2D slice per step (N[0] == 1).
# Confirm against how ``stps`` is filled.
im_stps = np.zeros([len(stps), N_im[-2], N_im[-1]])
gtv = np.zeros([len(stps), N_im[-2], N_im[-1]])
gxfm = np.zeros([len(stps), N_im[-2], N_im[-1]])
gdc = np.zeros([len(stps), N_im[-2], N_im[-1]])
for jj in range(len(stps)):
    # Image for step jj.
    im_stps[jj,:,:] = tf.iwt(stps[jj].reshape(N[-2:]),wavelet,mode,dims,dimOpt,dimLenOpt)
    # TV gradient. The fifth positional argument (dirWeight elsewhere in
    # this file) is 0 here.
    gtv[jj,:,:] = grads.gTV(im_stps[jj,:,:].reshape(N_im),N_im,strtag, kern, 0, a=a)
    # XFM gradient, mapped to image space with tf.iwt.
    gxfm[jj,:,:] = tf.iwt(grads.gXFM(stps[jj].reshape(N[-2:]),a=a),wavelet,mode,dims,dimOpt,dimLenOpt)
    # Data-consistency gradient.
    gdc[jj,:,:] = grads.gDataCons(im_stps[jj,:,:], N_im, ph_scan, data, k)
    

    
for i in xrange(len(stps)):
def runCSAlgorithm(fromfid=False,
                   filename='/home/asalerno/Documents/pyDirectionCompSense/brainData/P14/data/fullySampledBrain.npy',
                   sliceChoice=150,
                   strtag = ['','spatial', 'spatial'],
                   xtol = [1e-2, 1e-3, 5e-4, 5e-4],
                   TV = [0.01, 0.005, 0.002, 0.001],
                   XFM = [0.01,.005, 0.002, 0.001],
                   dirWeight=0,
                   pctg=0.25,
                   radius=0.2,
                   P=2,
                   pft=False,
                   ext=0.5,
                   wavelet='db4',
                   mode='per',
                   method='CG',
                   ItnLim=30,
                   lineSearchItnLim=30,
                   alpha_0=0.6,
                   c=0.6,
                   a=10.0,
                   kern=np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                                  [[0., 0., 0.], [0., -1., 0.], [0., 1., 0.]],
                                  [[0., 0., 0.], [0., -1., 1.], [0., 0., 0.]]]),
                   dirFile = None,
                   nmins = None,
                   dirs = None,
                   M = None,
                   dirInfo = [None]*4,
                   saveNpy=False,
                   saveNpyFile=None,
                   saveImsPng=False,
                   saveImsPngFile=None,
                   saveImDiffPng=False,
                   saveImDiffPngFile=None,
                   disp=False):
    '''
    Simulate an undersampled acquisition of one slice and reconstruct it with
    compressed sensing from four different starting guesses.

    Loading:
        The fully sampled slice comes from a ``.npy`` volume, or from a FID
        when ``fromfid`` is true. In the FID case ``filename`` is
        ``(inputdirectory, petable)``.

    Undersampling:
        The sampling pattern is drawn with ``samp.genPDF`` /
        ``samp.genSampling``.

    Reconstruction:
        The wavelet coefficients are reconstructed by minimising ``f``
        (gradient ``df``) with ``opt.minimize``, once for each
        (TV[i], XFM[i], xtol[i]) in turn. A step whose result contains nan,
        or whose status is non-zero, is discarded.

    Starting guesses:
        density-corrected image, zeros, ones, and Gaussian noise shifted and
        scaled to [0, 1].

    Output:
        Optionally saves the reconstructions (``saveNpy``), a figure of them
        (``saveImsPng``) and a figure of their pairwise differences
        (``saveImDiffPng``). Figures are shown instead of saved when ``disp``
        is true.

    ``dirFile`` and ``M`` are accepted but unused.
    '''
    if fromfid:
        inputdirectory=filename[0]
        petable=filename[1]
        fullImData = rff.getDataFromFID(petable,inputdirectory,2)[0,:,:,:]
        fullImData = fullImData/np.max(abs(fullImData))
        im = fullImData[:,:,sliceChoice]
    else:
        im = np.load(filename)[sliceChoice,:,:]

    N = np.array(im.shape)  # image Size

    pdf = samp.genPDF(N[-2:], P, pctg, radius=radius, cyl=np.hstack([1, N[-2:]]), style='mult', pft=pft, ext=ext)
    if pft:
        print('Partial Fourier sampling method used')
    k = samp.genSampling(pdf, 50, 2)[0].astype(int)
    if len(N) == 2:
        N = np.hstack([1, N])
        k = k.reshape(N)
        im = im.reshape(N)
    elif (len(N) == 3) and ('dir' not in strtag):
        k = k.reshape(np.hstack([1,N[-2:]])).repeat(N[0],0)

    ph_ones = np.ones(N[-2:], complex)
    ph_scan = np.zeros(N, complex)
    data = np.zeros(N,complex)
    im_scan = np.zeros(N,complex)
    for i in range(N[0]):
        k[i,:,:] = np.fft.fftshift(k[i,:,:])
        data[i,:,:] = k[i,:,:]*tf.fft2c(im[i,:,:], ph=ph_ones)

        # IMAGE from the "scanner data"
        im_scan_wph = tf.ifft2c(data[i,:,:], ph=ph_ones)
        ph_scan[i,:,:] = tf.matlab_style_gauss2D(im_scan_wph,shape=(5,5))
        ph_scan[i,:,:] = np.exp(1j*ph_scan[i,:,:])
        im_scan[i,:,:] = tf.ifft2c(data[i,:,:], ph=ph_scan[i,:,:])

    # ------------------------------------------------------------------ #
    # A quick way to look at the PSF of the sampling pattern that we use #
    delta = np.zeros(N[-2:])
    delta[int(N[-2]/2),int(N[-1]/2)] = 1
    psf = tf.ifft2c(tf.fft2c(delta,ph_ones)*k,ph_ones)
    # ------------------------------------------------------------------ #

    ## ------------------------------------------------------------------ #
    ## -- Currently broken - Need to figure out what's happening here. -- #
    ## ------------------------------------------------------------------ #
    #if pft:
        #for i in xrange(N[0]):
            #dataHold = np.fft.fftshift(data[i,:,:])
            #kHold = np.fft.fftshift(k[i,:,:])
            #loc = 98
            #for ix in xrange(N[-2]):
                #for iy in xrange(loc,N[-1]):
                    #dataHold[-ix,-iy] = dataHold[ix,iy].conj()
                    #kHold[-ix,-iy] = kHold[ix,iy]
    ## ------------------------------------------------------------------ #

    # Density compensation divisor: the PDF, with zero entries replaced by 1
    # so that unsampled locations are not divided by zero.
    pdfDiv = pdf.copy()
    pdfZeros = np.where(pdf==0)
    pdfDiv[pdfZeros] = 1
    pdfDivShifted = np.fft.ifftshift(pdfDiv)  # loop-invariant, computed once

    N_im = N.copy()
    hld, dims, dimOpt, dimLenOpt = tf.wt(im_scan[0].real,wavelet,mode)
    N = np.hstack([N_im[0], hld.shape])  # shape of the wavelet coefficients

    w_scan = np.zeros(N)
    w_full = np.zeros(N)
    im_dc = np.zeros(N_im)
    w_dc = np.zeros(N)

    for i in range(N[0]):
        w_scan[i,:,:] = tf.wt(im_scan.real[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)[0]
        w_full[i,:,:] = tf.wt(abs(im[i,:,:]),wavelet,mode,dims,dimOpt,dimLenOpt)[0]

        im_dc[i,:,:] = tf.ifft2c(data[i,:,:] / pdfDivShifted, ph=ph_scan[i,:,:]).real.copy()
        # Transform this slice only, not the whole im_dc volume.
        w_dc[i,:,:] = tf.wt(im_dc[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)[0]

    w_dc = w_dc.flatten()
    im_sp = im_dc.copy().reshape(N_im)
    minval = np.min(abs(im))
    maxval = np.max(abs(im))
    data = np.ascontiguousarray(data)

    # Starting guesses for the reconstruction; labels are in ``mets``.
    imdcs = [im_dc,np.zeros(N_im),np.ones(N_im),np.random.randn(np.prod(N_im)).reshape(N_im)]
    imdcs[-1] = imdcs[-1] - np.min(imdcs[-1])
    imdcs[-1] = imdcs[-1]/np.max(abs(imdcs[-1]))
    mets = ['Density Corrected','Zeros','1/2''s','Gaussian Random Shift (0,1)']
    wdcs = []
    for imdc in imdcs:
        # Wavelet-transform every slice of the starting image.
        w_init = np.zeros(N)
        for i in range(N[0]):
            w_init[i,:,:] = tf.wt(imdc[i],wavelet,mode,dims,dimOpt,dimLenOpt)[0]
        wdcs.append(w_init)

    ims = []
    for w_dc, met in zip(wdcs, mets):
        print(met)
        for i in range(len(TV)):
            # NOTE(review): these positional args must line up with the
            # signature of the ``f``/``df`` in use. ``a`` occupies the slot
            # after ``mode``. Confirm that slot is ``a`` and not ``level``.
            args = (N, N_im, dims, dimOpt, dimLenOpt, TV[i], XFM[i], data, k, strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins, wavelet, mode, a)
            w_result = opt.minimize(f, w_dc, args=args, method=method, jac=df,
                                        options={'maxiter': ItnLim, 'lineSearchItnLim': lineSearchItnLim, 'gtol': 0.01, 'disp': 1, 'alpha_0': alpha_0, 'c': c, 'xtol': xtol[i], 'TVWeight': TV[i], 'XFMWeight': XFM[i], 'N': N})
            if np.any(np.isnan(w_result['x'])):
                print('Some nan''s found. Dropping TV and XFM values')
            elif w_result['status'] != 0:
                print('TV and XFM values too high -- no solution found. Dropping...')
            else:
                w_dc = w_result['x']

        # Back to image space, slice by slice.
        w_res = w_dc.reshape(N)
        im_res = np.zeros(N_im)
        for i in range(N[0]):
            im_res[i,:,:] = tf.iwt(w_res[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)
        ims.append(im_res)

    if saveNpy:
        if saveNpyFile is None:
            np.save('./holdSave_im_res_' + str(int(pctg*100)) + 'p_all_SP',ims)
        else:
            np.save(saveNpyFile,ims)

    if saveImsPng:
        vis.figSubplots(ims,titles=mets,clims=(minval,maxval),colorbar=True)
        if not disp:
            if saveImsPngFile is None:
                saveFig.save('./holdSave_ims_' + str(int(pctg*100)) + 'p_all_SP')
            else:
                saveFig.save(saveImsPngFile)

    if saveImDiffPng:
        imdiffs, clims = vis.imDiff(ims)
        diffMets = ['DC-Zeros','DC-Ones','DC-Random','Zeros-Ones','Zeros-Random','Ones-Random']
        vis.figSubplots(imdiffs,titles=diffMets,clims=clims,colorbar=True)
        if not disp:
            if saveImDiffPngFile is None:
                saveFig.save('./holdSave_im_diffs_' + str(int(pctg*100)) + 'p_all_SP')
            else:
                saveFig.save(saveImDiffPngFile)

    if disp:
        plt.show()
im_dc = np.zeros(N_im)
w_dc = np.zeros(N)

# Timestamps bracket the optimisation run. A bare strftime(...) expression
# in a script evaluates and discards its result, so print it explicitly.
print('Start: ' + strftime("%Y-%m-%d %H:%M:%S", localtime()))

# NOTE(review): np.prod(N_im) is the third argument here, matching the
# objectiveFunction(x, N, N_im, sz, ...) signature. Confirm it matches the
# ``f``/``df`` bound at this point.
args = (N, N_im, np.prod(N_im), dims, dimOpt, dimLenOpt, TV[0], XFM[0], data, k, strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins, wavelet, mode, a)

w_result = opt.fmin_tnc(f, w_dc.flat, fprime=df, args=args, accuracy=1e-4, disp=5)

# The first element of fmin_tnc's result is the solution vector.
wHold = w_result[0].copy().reshape(N)
imHold = np.zeros(N_im)

# Back to image space, slice by slice.
for i in range(N[0]):
    imHold[i,:,:] = tf.iwt(wHold[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt)

print('End: ' + strftime("%Y-%m-%d %H:%M:%S", localtime()))





##pdfDiv = pdf.copy()
##pdfZeros = np.where(pdf==0)
##pdfDiv[pdfZeros] = 1
##im_scan_imag = im_scan.imag
##im_scan = im_scan

#x, y = np.meshgrid(np.linspace(-1,1,N[-1]),np.linspace(-1,1,N[-2]))
#locs = (abs(x)<=radius) * (abs(y)<=radius)
Exemplo n.º 9
0
# NOTE(review): np.prod(N_im) is the third argument here, matching the
# objectiveFunction(x, N, N_im, sz, ...) signature. Confirm it matches the
# ``f``/``df`` bound at this point.
args = (N, N_im, np.prod(N_im), dims, dimOpt, dimLenOpt, TV[0], XFM[0], data,
        k, strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins, wavelet,
        mode, a)

w_result = opt.fmin_tnc(f,
                        w_dc.flat,
                        fprime=df,
                        args=args,
                        accuracy=1e-4,
                        disp=5)

# The first element of fmin_tnc's result is the solution vector.
wHold = w_result[0].copy().reshape(N)
imHold = np.zeros(N_im)

# Back to image space, slice by slice.
for i in range(N[0]):
    imHold[i, :, :] = tf.iwt(wHold[i, :, :], wavelet, mode, dims, dimOpt,
                             dimLenOpt)

# A bare strftime(...) expression in a script discards its result; print the
# completion timestamp explicitly.
print('End: ' + strftime("%Y-%m-%d %H:%M:%S", localtime()))

##pdfDiv = pdf.copy()
##pdfZeros = np.where(pdf==0)
##pdfDiv[pdfZeros] = 1
##im_scan_imag = im_scan.imag
##im_scan = im_scan

#x, y = np.meshgrid(np.linspace(-1,1,N[-1]),np.linspace(-1,1,N[-2]))
#locs = (abs(x)<=radius) * (abs(y)<=radius)
#minLoc = np.min(np.where(locs==True))

#pctgSamp = np.zeros(minLoc+1)
#for i in range(1,minLoc+1):
Exemplo n.º 10
0
        #stps.append(w_dc)
        #w_stp.append(w_dc.reshape(NSub))
        #im_hld = np.zeros(N_imSub)
        #for i in range(NSub[0]):
        #im_hld[i] = tf.iwt(w_stp[-1][i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)
        #imStp.append(im_hld)
        #plt.imshow(imStp[-1],clim=(minval,maxval)); plt.colorbar(); plt.show()
        #w_dc = w_stp[k].flatten()
        #stps.append(w_dc)
        #wdcHold = w_dc.reshape(NSub)
        #dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub),axes=(-2,-1))
        #kStp = np.fft.fftshift(kSub,axes=(-2,-1)).copy()
        #kMaskRpt = kMasked.reshape(np.hstack([1,N_imSub[-2:]])).repeat(N_imSub[0],0)
        im_hld = np.zeros(N_imSub)
        for i in range(NSub[0]):
            im_hld[i] = tf.iwt(w_dc[i], wavelet, mode, dimsSub, dimOptSub,
                               dimLenOptSub)
        data_dc = tf.fft2c(im_hld, ph=ph_scanSub, axes=(-2, -1))
        kMasked = (np.floor(1 - kMasked) * pctgSamp[locSteps[j]] * lam_trust +
                   kMasked)
        #kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*1.0 + kMasked).reshape(np.hstack([1, N_imSub[-2:]]))

    wHold = w_dc.copy().reshape(NSub)
    imHold = np.zeros(N_imSub)

    for i in xrange(N[0]):
        imHold[i, :, :] = tf.iwt(wHold[i, :, :], wavelet, mode, dimsSub,
                                 dimOptSub, dimLenOptSub)

    np.save(
        'temp/' + str(int(100 * pctg)) + '_slice_' + str(sliceChoice) +
        '_TV_' + str(TV) + '_XFM_' + str(XFM) + '_lamTrust_' +
Exemplo n.º 11
0
                                    'N': N
                                })
        if np.any(np.isnan(w_result['x'])):
            print('Some nan' 's found. Dropping TV and XFM values')
        elif w_result['status'] != 0:
            print(
                'TV and XFM values too high -- no solution found. Dropping...')
        else:
            w_dc = w_result['x']
            stps.append(w_dc)
            tvStps.append(TV[i])

    w_res = w_dc.reshape(N)
    im_res = np.zeros(N_im)
    for i in xrange(N[0]):
        im_res[i, :, :] = tf.iwt(w_res[i, :, :], wavelet, mode, dims, dimOpt,
                                 dimLenOpt)
    ims.append(im_res)

im_stps = np.zeros([len(stps), N_im[-2], N_im[-1]])
gtv = np.zeros([len(stps), N_im[-2], N_im[-1]])
gxfm = np.zeros([len(stps), N_im[-2], N_im[-1]])
gdc = np.zeros([len(stps), N_im[-2], N_im[-1]])
for jj in range(len(stps)):
    im_stps[jj, :, :] = tf.iwt(stps[jj].reshape(N[-2:]), wavelet, mode, dims,
                               dimOpt, dimLenOpt)
    gtv[jj, :, :] = grads.gTV(im_stps[jj, :, :].reshape(N_im),
                              N_im,
                              strtag,
                              kern,
                              0,
                              a=a)
Exemplo n.º 12
0
     #stps.append(w_dc)
     #wdcHold = w_dc.reshape(NSub)
     #kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked
 #if j == len(TV):
     #print('No solution found on final run. Saving last spot.')
     #w_dc = w_result['x']
     ##import pdb; pdb.set_trace()
     #stps.append(w_dc)
     #tvStps.append(TV[i])
     #w_stp.append(w_dc.reshape(NSub))
     #imStp.append(tf.iwt(w_stp[-1][0],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub))
     ##plt.imshow(imStp[-1]); plt.colorbar(); plt.show()
     ##w_dc = w_stp[k].flatten()
     #stps.append(w_dc)
     #wdcHold = w_dc.reshape(NSub)
     ##dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub))
     ##kStp = np.fft.fftshift(kSub).copy()
     #kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked
 w_dc = w_result['x']
 #import pdb; pdb.set_trace()
 stps.append(w_dc)
 tvStps.append(TV[i])
 w_stp.append(w_dc.reshape(NSub))
 imStp.append(tf.iwt(w_stp[-1][0],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub))
 #plt.imshow(imStp[-1]); plt.colorbar(); plt.show()
 #w_dc = w_stp[k].flatten()
 stps.append(w_dc)
 wdcHold = w_dc.reshape(NSub)
 kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked    
 dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub))
 kStp = np.fft.fftshift(kSub).copy()
Exemplo n.º 13
0
def runCSAlgorithm(
        fromfid=False,
        filename='/home/asalerno/Documents/pyDirectionCompSense/brainData/P14/data/fullySampledBrain.npy',
        sliceChoice=150,
        strtag=['', 'spatial', 'spatial'],
        xtol=[1e-2, 1e-3, 5e-4, 5e-4],
        TV=[0.01, 0.005, 0.002, 0.001],
        XFM=[0.01, .005, 0.002, 0.001],
        dirWeight=0,
        pctg=0.25,
        radius=0.2,
        P=2,
        pft=False,
        ext=0.5,
        wavelet='db4',
        mode='per',
        method='CG',
        ItnLim=30,
        lineSearchItnLim=30,
        alpha_0=0.6,
        c=0.6,
        a=10.0,
        kern=np.array([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                       [[0., 0., 0.], [0., -1., 0.], [0., 1., 0.]],
                       [[0., 0., 0.], [0., -1., 1.], [0., 0., 0.]]]),
        dirFile=None,
        nmins=None,
        dirs=None,
        M=None,
        dirInfo=[None] * 4,
        saveNpy=False,
        saveNpyFile=None,
        saveImsPng=False,
        saveImsPngFile=None,
        saveImDiffPng=False,
        saveImDiffPngFile=None,
        disp=False):
    '''
    Simulate an undersampled acquisition of one slice and reconstruct it with
    compressed sensing from four different starting guesses.

    Loading:
        The fully sampled slice comes from a ``.npy`` volume, or from a FID
        when ``fromfid`` is true. In the FID case ``filename`` is
        ``(inputdirectory, petable)``.

    Undersampling:
        The sampling pattern is drawn with ``samp.genPDF`` /
        ``samp.genSampling``.

    Reconstruction:
        The wavelet coefficients are reconstructed by minimising ``f``
        (gradient ``df``) with ``opt.minimize``, once for each
        (TV[i], XFM[i], xtol[i]) in turn. A step whose result contains nan,
        or whose status is non-zero, is discarded.

    Starting guesses:
        density-corrected image, zeros, ones, and Gaussian noise shifted and
        scaled to [0, 1].

    Output:
        Optionally saves the reconstructions (``saveNpy``), a figure of them
        (``saveImsPng``) and a figure of their pairwise differences
        (``saveImDiffPng``). Figures are shown instead of saved when ``disp``
        is true.

    ``dirFile`` and ``M`` are accepted but unused.
    '''
    if fromfid:
        inputdirectory = filename[0]
        petable = filename[1]
        fullImData = rff.getDataFromFID(petable, inputdirectory, 2)[0, :, :, :]
        fullImData = fullImData / np.max(abs(fullImData))
        im = fullImData[:, :, sliceChoice]
    else:
        im = np.load(filename)[sliceChoice, :, :]

    N = np.array(im.shape)  # image Size

    pdf = samp.genPDF(N[-2:],
                      P,
                      pctg,
                      radius=radius,
                      cyl=np.hstack([1, N[-2:]]),
                      style='mult',
                      pft=pft,
                      ext=ext)
    if pft:
        print('Partial Fourier sampling method used')
    k = samp.genSampling(pdf, 50, 2)[0].astype(int)
    if len(N) == 2:
        N = np.hstack([1, N])
        k = k.reshape(N)
        im = im.reshape(N)
    elif (len(N) == 3) and ('dir' not in strtag):
        k = k.reshape(np.hstack([1, N[-2:]])).repeat(N[0], 0)

    ph_ones = np.ones(N[-2:], complex)
    ph_scan = np.zeros(N, complex)
    data = np.zeros(N, complex)
    im_scan = np.zeros(N, complex)
    for i in range(N[0]):
        k[i, :, :] = np.fft.fftshift(k[i, :, :])
        data[i, :, :] = k[i, :, :] * tf.fft2c(im[i, :, :], ph=ph_ones)

        # IMAGE from the "scanner data"
        im_scan_wph = tf.ifft2c(data[i, :, :], ph=ph_ones)
        ph_scan[i, :, :] = tf.matlab_style_gauss2D(im_scan_wph, shape=(5, 5))
        ph_scan[i, :, :] = np.exp(1j * ph_scan[i, :, :])
        im_scan[i, :, :] = tf.ifft2c(data[i, :, :], ph=ph_scan[i, :, :])

    # ------------------------------------------------------------------ #
    # A quick way to look at the PSF of the sampling pattern that we use #
    delta = np.zeros(N[-2:])
    delta[int(N[-2] / 2), int(N[-1] / 2)] = 1
    psf = tf.ifft2c(tf.fft2c(delta, ph_ones) * k, ph_ones)
    # ------------------------------------------------------------------ #

    ## ------------------------------------------------------------------ #
    ## -- Currently broken - Need to figure out what's happening here. -- #
    ## ------------------------------------------------------------------ #
    #if pft:
    #for i in xrange(N[0]):
    #dataHold = np.fft.fftshift(data[i,:,:])
    #kHold = np.fft.fftshift(k[i,:,:])
    #loc = 98
    #for ix in xrange(N[-2]):
    #for iy in xrange(loc,N[-1]):
    #dataHold[-ix,-iy] = dataHold[ix,iy].conj()
    #kHold[-ix,-iy] = kHold[ix,iy]
    ## ------------------------------------------------------------------ #

    # Density compensation divisor: the PDF, with zero entries replaced by 1
    # so that unsampled locations are not divided by zero.
    pdfDiv = pdf.copy()
    pdfZeros = np.where(pdf == 0)
    pdfDiv[pdfZeros] = 1
    pdfDivShifted = np.fft.ifftshift(pdfDiv)  # loop-invariant, computed once

    N_im = N.copy()
    hld, dims, dimOpt, dimLenOpt = tf.wt(im_scan[0].real, wavelet, mode)
    N = np.hstack([N_im[0], hld.shape])  # shape of the wavelet coefficients

    w_scan = np.zeros(N)
    w_full = np.zeros(N)
    im_dc = np.zeros(N_im)
    w_dc = np.zeros(N)

    for i in range(N[0]):
        w_scan[i, :, :] = tf.wt(im_scan.real[i, :, :], wavelet, mode, dims,
                                dimOpt, dimLenOpt)[0]
        w_full[i, :, :] = tf.wt(abs(im[i, :, :]), wavelet, mode, dims, dimOpt,
                                dimLenOpt)[0]

        im_dc[i, :, :] = tf.ifft2c(data[i, :, :] / pdfDivShifted,
                                   ph=ph_scan[i, :, :]).real.copy()
        # Transform this slice only, not the whole im_dc volume.
        w_dc[i, :, :] = tf.wt(im_dc[i, :, :], wavelet, mode, dims, dimOpt,
                              dimLenOpt)[0]

    w_dc = w_dc.flatten()
    im_sp = im_dc.copy().reshape(N_im)
    minval = np.min(abs(im))
    maxval = np.max(abs(im))
    data = np.ascontiguousarray(data)

    # Starting guesses for the reconstruction; labels are in ``mets``.
    imdcs = [
        im_dc,
        np.zeros(N_im),
        np.ones(N_im),
        np.random.randn(np.prod(N_im)).reshape(N_im)
    ]
    imdcs[-1] = imdcs[-1] - np.min(imdcs[-1])
    imdcs[-1] = imdcs[-1] / np.max(abs(imdcs[-1]))
    mets = [
        'Density Corrected', 'Zeros', '1/2'
        's', 'Gaussian Random Shift (0,1)'
    ]
    wdcs = []
    for imdc in imdcs:
        # Wavelet-transform every slice of the starting image.
        w_init = np.zeros(N)
        for i in range(N[0]):
            w_init[i, :, :] = tf.wt(imdc[i], wavelet, mode, dims, dimOpt,
                                    dimLenOpt)[0]
        wdcs.append(w_init)

    ims = []
    for w_dc, met in zip(wdcs, mets):
        print(met)
        for i in range(len(TV)):
            # NOTE(review): these positional args must line up with the
            # signature of the ``f``/``df`` in use. ``a`` occupies the slot
            # after ``mode``. Confirm that slot is ``a`` and not ``level``.
            args = (N, N_im, dims, dimOpt, dimLenOpt, TV[i], XFM[i], data, k,
                    strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins,
                    wavelet, mode, a)
            w_result = opt.minimize(f,
                                    w_dc,
                                    args=args,
                                    method=method,
                                    jac=df,
                                    options={
                                        'maxiter': ItnLim,
                                        'lineSearchItnLim': lineSearchItnLim,
                                        'gtol': 0.01,
                                        'disp': 1,
                                        'alpha_0': alpha_0,
                                        'c': c,
                                        'xtol': xtol[i],
                                        'TVWeight': TV[i],
                                        'XFMWeight': XFM[i],
                                        'N': N
                                    })
            if np.any(np.isnan(w_result['x'])):
                print('Some nan' 's found. Dropping TV and XFM values')
            elif w_result['status'] != 0:
                print(
                    'TV and XFM values too high -- no solution found. Dropping...'
                )
            else:
                w_dc = w_result['x']

        # Back to image space, slice by slice.
        w_res = w_dc.reshape(N)
        im_res = np.zeros(N_im)
        for i in range(N[0]):
            im_res[i, :, :] = tf.iwt(w_res[i, :, :], wavelet, mode, dims,
                                     dimOpt, dimLenOpt)
        ims.append(im_res)

    if saveNpy:
        if saveNpyFile is None:
            np.save('./holdSave_im_res_' + str(int(pctg * 100)) + 'p_all_SP',
                    ims)
        else:
            np.save(saveNpyFile, ims)

    if saveImsPng:
        vis.figSubplots(ims,
                        titles=mets,
                        clims=(minval, maxval),
                        colorbar=True)
        if not disp:
            if saveImsPngFile is None:
                saveFig.save('./holdSave_ims_' + str(int(pctg * 100)) +
                             'p_all_SP')
            else:
                saveFig.save(saveImsPngFile)

    if saveImDiffPng:
        imdiffs, clims = vis.imDiff(ims)
        diffMets = [
            'DC-Zeros', 'DC-Ones', 'DC-Random', 'Zeros-Ones', 'Zeros-Random',
            'Ones-Random'
        ]
        vis.figSubplots(imdiffs, titles=diffMets, clims=clims, colorbar=True)
        if not disp:
            if saveImDiffPngFile is None:
                saveFig.save('./holdSave_im_diffs_' + str(int(pctg * 100)) +
                             'p_all_SP')
            else:
                saveFig.save(saveImDiffPngFile)

    if disp:
        plt.show()
        #stps.append(w_dc)
        #w_stp.append(w_dc.reshape(NSub))
        #im_hld = np.zeros(N_imSub)
        #for i in range(NSub[0]):
            #im_hld[i] = tf.iwt(w_stp[-1][i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)
        #imStp.append(im_hld)
        #plt.imshow(imStp[-1],clim=(minval,maxval)); plt.colorbar(); plt.show()
        #w_dc = w_stp[k].flatten()
        #stps.append(w_dc)
        #wdcHold = w_dc.reshape(NSub)
        #dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub),axes=(-2,-1))
        #kStp = np.fft.fftshift(kSub,axes=(-2,-1)).copy()
        #kMaskRpt = kMasked.reshape(np.hstack([1,N_imSub[-2:]])).repeat(N_imSub[0],0)
        im_hld = np.zeros(N_imSub)
        for i in range(NSub[0]):
            im_hld[i] = tf.iwt(w_dc[i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)
        data_dc = tf.fft2c(im_hld, ph=ph_scanSub, axes=(-2,-1))
        kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*lam_trust + kMasked)
        #kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*1.0 + kMasked).reshape(np.hstack([1, N_imSub[-2:]]))
    
wHold = w_dc.copy().reshape(NSub)
imHold = np.zeros(N_imSub)

# Rebuild the image slice by slice. wHold/imHold are shaped NSub/N_imSub,
# so bound the loop by wHold's own leading axis rather than N[0]
# (presumably equal; confirm).
for i in range(wHold.shape[0]):
    imHold[i,:,:] = tf.iwt(wHold[i,:,:],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)

# Save the image and the wavelet coefficients under distinct names, so the
# coefficients do not overwrite the image file.
outPrefix = ('/hpf/largeprojects/MICe/asalerno/pyDirectionCompSense/tests/fullBrainTests/'
             + str(int(100*pctg)) + '_3_spatial_TV_')
outSuffix = str(int(nSteps)) + '_comb_ks.npy'
np.save(outPrefix + 'im_final' + outSuffix, imHold)
np.save(outPrefix + 'w_final' + outSuffix, wHold)

#outvol = volumeFromData('/hpf/largeprojects/MICe/asalerno/pyDirectionCompSense/tests/fullBrainTests/' + str(int(100*pctg)) + '_3_spatial_TV_im_final_' + str(int(nSteps)) + '_comb_ks.mnc', imHold, dimnames=['xspace','yspace','zspace'], starts=(0, 0, 0), steps=(1, 1, 1), volumeType="uint")
Exemplo n.º 15
0
def objectiveFunction(x,
                      N,
                      N_im,
                      sz,
                      dims,
                      dimOpt,
                      dimLenOpt,
                      lam1,
                      lam2,
                      data,
                      k,
                      strtag,
                      ph,
                      kern,
                      dirWeight=0,
                      dirs=None,
                      dirInfo=None,
                      nmins=0,
                      wavelet='db4',
                      mode="per",
                      a=10.):
    '''
    Compressed-sensing cost for the wavelet coefficients ``x``.

    cost = real part of the data-consistency term
           + lam1 * sum(|TV term|)
           + lam2 * sum(|log(cosh(a*x)) / a|)

    The TV and XFM terms are skipped when their weight is <= 1e-6.
    ``sz`` and ``strtag`` are forwarded to ``objectiveFunctionDataCons``.
    ``dirInfo=None`` means ``[None, None, None, None]`` (``dirInfo[0]`` is M).

    ``x`` and ``data`` are temporarily reshaped in place (to ``N`` and
    ``N_im``). They are flattened again before returning, even on error.

    Returns the absolute value of the total cost.
    '''
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    try:
        data.shape = N_im
        x.shape = N
        if len(N) > 2:
            # Rebuild the (complex) image slice by slice.
            x0 = np.zeros(N_im, complex)
            for i in range(N[0]):
                x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt,
                                     dimLenOpt)
        else:
            x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)

        obj = np.sum(objectiveFunctionDataCons(x0, N_im, ph, data, k, sz,
                                               strtag)).real

        if lam1 > 1e-6:
            tv = np.sum(
                abs(
                    objectiveFunctionTV(x0,
                                        N_im,
                                        strtag,
                                        kern,
                                        dirWeight,
                                        dirs,
                                        nmins,
                                        dirInfo=dirInfo,
                                        a=a)))

        if lam2 > 1e-6:
            ax = a * x
            if np.iscomplexobj(ax):
                # Complex input: keep the direct formula; 1.0/a avoids
                # Python 2 integer division.
                xfm = np.sum(abs((1.0 / a) * np.log(np.cosh(ax))))
            else:
                # log(cosh(z)) == |z| + log1p(exp(-2|z|)) - log(2): exact,
                # and unlike np.cosh it never overflows for large |z|.
                az = np.abs(ax)
                xfm = np.sum(
                    abs((az + np.log1p(np.exp(-2.0 * az)) - np.log(2.0)) / a))
    finally:
        # Restore the flat shapes the optimizer expects.
        x.shape = (x.size, )
        data.shape = (data.size, )
    return abs(obj + lam1 * tv + lam2 * xfm)