def objectiveFunction(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k,
                      strtag, ph, kern, dirWeight=0, dirs=None, dirInfo=None,
                      nmins=0, wavelet='db4', mode="per", a=10.):
    '''
    Objective (cost) function for the CS optimization.

    Evaluates, for the wavelet-domain iterate x:
        |data-consistency| + lam1 * TV + lam2 * XFM(sparsity)
    where the TV and XFM terms are only computed when their weights exceed
    1e-6. Returns the absolute value of the weighted sum (a scalar).

    Side effects: temporarily reshapes `x` and `data` in place (restored to
    flat vectors before returning), as required by the scipy optimizers that
    pass 1-D arrays.

    NOTE(review): `x` and `data` must be contiguous for in-place `.shape`
    assignment to succeed -- confirm callers guarantee this.
    '''
    # dirInfo[0] is M.  Use a None sentinel instead of a mutable list default
    # so the default object is not shared across calls.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    data.shape = N_im
    x.shape = N
    if len(N) > 2:
        # Multi-slice case: inverse wavelet transform each slice back to
        # image space before evaluating the image-domain terms.
        x0 = np.zeros(N_im)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    obj = np.sum(objectiveFunctionDataCons(x0, N_im, ph, data, k))
    if lam1 > 1e-6:
        tv = np.sum(objectiveFunctionTV(x0, N_im, strtag, kern, dirWeight,
                                        dirs, nmins, dirInfo=dirInfo, a=a))
    if lam2 > 1e-6:
        # Smooth |.| surrogate: (1/a) * log(cosh(a*x)).  Use 1./a so an
        # integer `a` does not truncate to zero under Python 2 division.
        xfm = np.sum((1. / a) * np.log(np.cosh(a * x)))
    # Not the most efficient way to do this, but we need the shape to reset
    # so the optimizer keeps seeing flat vectors.
    x.shape = (x.size,)
    data.shape = (data.size,)
    return abs(obj + lam1 * tv + lam2 * xfm)
def derivativeFunction(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k,
                       strtag, ph, kern, dirWeight=0.1, dirs=None, dirInfo=None,
                       nmins=0, wavelet="db4", mode="per", a=10.):
    '''
    Gradient of `objectiveFunction`, for use as the `jac` argument of the
    scipy optimizers.

    Returns the flattened gradient of
        data-consistency + lam1 * TV + lam2 * XFM
    with respect to the wavelet coefficients `x`.  The TV and XFM gradients
    are only computed when their weights exceed 1e-6 (otherwise they
    contribute 0).

    Side effect: temporarily reshapes `x` in place (restored to a flat
    vector before returning).
    '''
    # None sentinel instead of a shared mutable list default.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    gTV = 0
    gXFM = 0
    x.shape = N
    if len(N) > 2:
        # Per-slice inverse wavelet transform back to image space.
        x0 = np.zeros(N_im)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    # Data-consistency gradient is computed in image space, then transformed
    # back to the wavelet domain ([0] drops the bookkeeping outputs of tf.wt).
    gDataCons = tf.wt(grads.gDataCons(x0, N_im, ph, data, k),
                      wavelet, mode, dims, dimOpt, dimLenOpt)[0]
    if lam1 > 1e-6:
        # Calculate the TV gradient.
        gTV = tf.wt(grads.gTV(x0, N_im, strtag, kern, dirWeight, dirs, nmins,
                              dirInfo=dirInfo, a=a),
                    wavelet, mode, dims, dimOpt, dimLenOpt)[0]
    if lam2 > 1e-6:
        # XFM (sparsity) gradient acts directly on the wavelet coefficients.
        gXFM = grads.gXFM(x, a=a)
    x.shape = (x.size,)
    # Export the flattened array.
    return (gDataCons + lam1 * gTV + lam2 * gXFM).flatten()
def derivativeFunction(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k,
                       strtag, ph, kern, dirWeight=0.1, dirs=None, dirInfo=None,
                       nmins=0, wavelet="db4", mode="per", a=10.):
    '''
    Gradient of the CS objective, usable as the `jac` callback for scipy
    optimizers.

    Returns the flattened gradient of
        data-consistency + lam1 * TV + lam2 * XFM
    with respect to the wavelet coefficients `x`.  TV/XFM gradients are
    skipped (contribute 0) when their weights are <= 1e-6.

    Side effect: temporarily reshapes `x` in place (restored to a flat
    vector before returning).
    '''
    # None sentinel instead of a shared mutable list default.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    gTV = 0
    gXFM = 0
    x.shape = N
    if len(N) > 2:
        # Per-slice inverse wavelet transform back to image space.
        x0 = np.zeros(N_im)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    # Image-space data-consistency gradient, transformed to wavelet domain
    # ([0] keeps only the coefficient array from tf.wt's outputs).
    gDataCons = tf.wt(grads.gDataCons(x0, N_im, ph, data, k),
                      wavelet, mode, dims, dimOpt, dimLenOpt)[0]
    if lam1 > 1e-6:
        # Calculate the TV gradient.
        gTV = tf.wt(grads.gTV(x0, N_im, strtag, kern, dirWeight, dirs, nmins,
                              dirInfo=dirInfo, a=a),
                    wavelet, mode, dims, dimOpt, dimLenOpt)[0]
    if lam2 > 1e-6:
        # XFM (sparsity) gradient acts directly on the wavelet coefficients.
        gXFM = grads.gXFM(x, a=a)
    x.shape = (x.size,)
    # Export the flattened array.
    return (gDataCons + lam1 * gTV + lam2 * gXFM).flatten()
def f(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph,
      kern, dirWeight=0, dirs=None, dirInfo=None, nmins=0, wavelet='db4',
      mode="per", level=3, a=10., mask=None, kmask=None):
    '''
    Objective (cost) function for the CS optimization, masked-k-space
    variant: data consistency + lam1 * TV + lam2 * XFM.

    Differs from `objectiveFunction` in that the data-consistency term takes
    a `kmask`, the XFM term replaces inf entries of log-cosh with |x| (the
    log-cosh surrogate overflows for large a*x), and the result is returned
    without an outer abs().

    Side effects: temporarily reshapes `x` and `data` in place (restored,
    made contiguous, before returning).  `level` and `mask` are accepted for
    signature compatibility with `df` but are not used here.
    '''
    # dirInfo[0] is M.  None sentinel instead of a shared mutable default.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    data.shape = N_im
    x.shape = N
    if len(N) > 2:
        # Per-slice inverse wavelet transform back to image space.
        x0 = np.zeros(N_im)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    obj = np.sum(objectiveFunctionDataCons(x0, N_im, ph, data, k, kmask=kmask))
    if lam1 > 1e-6:
        tv = np.sum(objectiveFunctionTV(x0, N_im, strtag, kern, dirWeight,
                                        dirs, nmins, dirInfo=dirInfo, a=a))
    if lam2 > 1e-6:
        # Smooth |.| surrogate; 1./a guards against integer `a` truncating
        # to zero under Python 2 division.
        xfm = (1. / a) * np.log(np.cosh(a * x))
        # log(cosh(z)) overflows to inf for large |z|; fall back to the
        # asymptote |x| at those entries.
        locs = np.where(np.isinf(xfm))
        xfm[locs] = abs(x[locs])
        xfm = np.sum(xfm)
    # Not the most efficient way to do this, but we need the shape to reset;
    # ascontiguousarray makes the in-place .shape assignment legal.
    x = np.ascontiguousarray(x)
    x.shape = (x.size,)
    data = np.ascontiguousarray(data)
    data.shape = (data.size,)
    # NOTE(review): leftover interactive debug hook -- drops into pdb when
    # the iterate diverges.  Consider raising instead for non-interactive runs.
    if np.any(np.isinf(x)) or np.any(np.isnan(x)):
        import pdb; pdb.set_trace()
    return obj + lam1 * tv + lam2 * xfm
def df(x, N, N_im, dims, dimOpt, dimLenOpt, lam1, lam2, data, k, strtag, ph,
       kern, dirWeight=0.1, dirs=None, dirInfo=None, nmins=0, wavelet="db4",
       mode="per", level=3, a=10., mask=None, kmask=None):
    '''
    Gradient of `f` (masked-k-space variant), for use as the `jac` callback.

    Computes the per-slice gradients of data consistency, TV and XFM, maps
    the image-domain ones back to the wavelet domain with a masked transform
    (`level`, `mask`), and returns the flattened weighted sum.

    Side effect: temporarily reshapes `x` in place (restored to a flat
    vector before returning).

    NOTE(review): `mask` is dereferenced unconditionally below, so despite
    the default it must not be None -- confirm against callers.
    '''
    # None sentinel instead of a shared mutable list default.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    x.shape = N
    if len(N) > 2:
        # Per-slice inverse wavelet transform back to image space.
        x0 = np.zeros(N_im)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    # Restrict the image-space iterate to the (broadcast) spatial mask.
    x0 = x0 * mask[np.newaxis, :, :]
    gdc = grads.gDataCons(x0, N_im, ph, data, k, kmask=kmask)
    if lam1 > 1e-6:
        gtv = grads.gTV(x0, N_im, strtag, kern, dirWeight, dirs, nmins,
                        dirInfo=dirInfo, a=a)
    gDataCons = np.zeros(N)
    gTV = np.zeros(N)
    gXFM = np.zeros(N)
    for i in xrange(N[0]):
        gDataCons[i, :, :] = tf.wt(gdc[i, :, :], wavelet, mode, level,
                                   dims, dimOpt, dimLenOpt, mask)[0]
        if lam1 > 1e-6:
            # Calculate the TV gradient.
            gTV[i, :, :] = tf.wt(gtv[i, :, :], wavelet, mode, level,
                                 dims, dimOpt, dimLenOpt, mask)[0]
        if lam2 > 1e-6:
            # XFM gradient acts directly on the wavelet coefficients.
            gXFM[i, :, :] = grads.gXFM(x[i, :, :], a=a)
    x.shape = (x.size,)
    # Export the flattened array.
    return (gDataCons + lam1 * gTV + lam2 * gXFM).flatten()
w_result = opt.minimize(f, w_dc, args=args, method=method, jac=df, options={'maxiter': ItnLim, 'lineSearchItnLim': lineSearchItnLim, 'gtol': 0.01, 'disp': 1, 'alpha_0': alpha_0, 'c': c, 'xtol': xtol[i], 'TVWeight': TV[i], 'XFMWeight': XFM[i], 'N': N}) if np.any(np.isnan(w_result['x'])): print('Some nan''s found. Dropping TV and XFM values') elif w_result['status'] != 0: print('TV and XFM values too high -- no solution found. Dropping...') else: w_dc = w_result['x'] stps.append(w_dc) tvStps.append(TV[i]) w_res = w_dc.reshape(N) im_res = np.zeros(N_im) for i in xrange(N[0]): im_res[i,:,:] = tf.iwt(w_res[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt) ims.append(im_res) im_stps = np.zeros([len(stps), N_im[-2], N_im[-1]]) gtv = np.zeros([len(stps), N_im[-2], N_im[-1]]) gxfm = np.zeros([len(stps), N_im[-2], N_im[-1]]) gdc = np.zeros([len(stps), N_im[-2], N_im[-1]]) for jj in range(len(stps)): im_stps[jj,:,:] = tf.iwt(stps[jj].reshape(N[-2:]),wavelet,mode,dims,dimOpt,dimLenOpt) gtv[jj,:,:] = grads.gTV(im_stps[jj,:,:].reshape(N_im),N_im,strtag, kern, 0, a=a) gxfm[jj,:,:] = tf.iwt(grads.gXFM(stps[jj].reshape(N[-2:]),a=a),wavelet,mode,dims,dimOpt,dimLenOpt) gdc[jj,:,:] = grads.gDataCons(im_stps[jj,:,:], N_im, ph_scan, data, k) for i in xrange(len(stps)):
def runCSAlgorithm(fromfid=False,
                   filename='/home/asalerno/Documents/pyDirectionCompSense/brainData/P14/data/fullySampledBrain.npy',
                   sliceChoice=150,
                   strtag=['', 'spatial', 'spatial'],
                   xtol=[1e-2, 1e-3, 5e-4, 5e-4],
                   TV=[0.01, 0.005, 0.002, 0.001],
                   XFM=[0.01, .005, 0.002, 0.001],
                   dirWeight=0,
                   pctg=0.25,
                   radius=0.2,
                   P=2,
                   pft=False,
                   ext=0.5,
                   wavelet='db4',
                   mode='per',
                   method='CG',
                   ItnLim=30,
                   lineSearchItnLim=30,
                   alpha_0=0.6,
                   c=0.6,
                   a=10.0,
                   kern=np.array([[[0., 0., 0.],
                                   [0., 0., 0.],
                                   [0., 0., 0.]],
                                  [[0., 0., 0.],
                                   [0., -1., 0.],
                                   [0., 1., 0.]],
                                  [[0., 0., 0.],
                                   [0., -1., 1.],
                                   [0., 0., 0.]]]),
                   dirFile=None,
                   nmins=None,
                   dirs=None,
                   M=None,
                   dirInfo=[None] * 4,
                   saveNpy=False,
                   saveNpyFile=None,
                   saveImsPng=False,
                   saveImsPngFile=None,
                   saveImDiffPng=False,
                   saveImDiffPngFile=None,
                   disp=False):
    '''
    End-to-end CS reconstruction driver for a single slice.

    Loads data (FID pair or .npy volume), builds a variable-density sampling
    pattern, simulates undersampled "scanner" data with a smoothed phase
    estimate, then runs the CS optimization (`opt.minimize` with `f`/`df`)
    from four different starting points (density-corrected, zeros, ones,
    shifted Gaussian random) over a decreasing schedule of TV/XFM weights.
    Optionally saves/plots the reconstructions and their pairwise diffs.

    NOTE(review): the list/array default arguments (strtag, xtol, TV, XFM,
    kern, dirInfo) are shared across calls; they are only read here, but
    callers must not mutate them.
    '''
    # --- Load the fully sampled image ---------------------------------- #
    if fromfid == True:
        # `filename` is (petable, inputdirectory) in this mode.
        inputdirectory = filename[0]
        petable = filename[1]
        fullImData = rff.getDataFromFID(petable, inputdirectory, 2)[0, :, :, :]
        fullImData = fullImData / np.max(abs(fullImData))
        im = fullImData[:, :, sliceChoice]
    else:
        im = np.load(filename)[sliceChoice, :, :]
    N = np.array(im.shape)  # image Size

    # --- Build the sampling PDF and a realized sampling mask ------------ #
    pdf = samp.genPDF(N[-2:], P, pctg, radius=radius,
                      cyl=np.hstack([1, N[-2:]]), style='mult', pft=pft, ext=ext)
    if pft:
        print('Partial Fourier sampling method used')
    k = samp.genSampling(pdf, 50, 2)[0].astype(int)
    # Normalize everything to a 3-D (slices, ny, nx) layout.
    if len(N) == 2:
        N = np.hstack([1, N])
        k = k.reshape(N)
        im = im.reshape(N)
    elif (len(N) == 3) and ('dir' not in strtag):
        k = k.reshape(np.hstack([1, N[-2:]])).repeat(N[0], 0)

    # --- Simulate scanner data and estimate a smooth phase per slice ---- #
    ph_ones = np.ones(N[-2:], complex)
    ph_scan = np.zeros(N, complex)
    data = np.zeros(N, complex)
    im_scan = np.zeros(N, complex)
    for i in range(N[0]):
        k[i, :, :] = np.fft.fftshift(k[i, :, :])
        data[i, :, :] = k[i, :, :] * tf.fft2c(im[i, :, :], ph=ph_ones)
        # IMAGE from the "scanner data"
        im_scan_wph = tf.ifft2c(data[i, :, :], ph=ph_ones)
        ph_scan[i, :, :] = tf.matlab_style_gauss2D(im_scan_wph, shape=(5, 5))
        ph_scan[i, :, :] = np.exp(1j * ph_scan[i, :, :])
        im_scan[i, :, :] = tf.ifft2c(data[i, :, :], ph=ph_scan[i, :, :])
    #im_lr = samp.loRes(im,pctg)

    # ------------------------------------------------------------------ #
    # A quick way to look at the PSF of the sampling pattern that we use  #
    delta = np.zeros(N[-2:])
    delta[int(N[-2] / 2), int(N[-1] / 2)] = 1
    psf = tf.ifft2c(tf.fft2c(delta, ph_ones) * k, ph_ones)
    # ------------------------------------------------------------------ #

    # Density correction: divide sampled k-space by the PDF (guarding the
    # zero entries) to de-bias the zero-filled reconstruction.
    pdfDiv = pdf.copy()
    pdfZeros = np.where(pdf == 0)
    pdfDiv[pdfZeros] = 1

    # --- Move into the wavelet domain ----------------------------------- #
    N_im = N.copy()
    hld, dims, dimOpt, dimLenOpt = tf.wt(im_scan[0].real, wavelet, mode)
    N = np.hstack([N_im[0], hld.shape])
    w_scan = np.zeros(N)
    w_full = np.zeros(N)
    im_dc = np.zeros(N_im)
    w_dc = np.zeros(N)
    for i in xrange(N[0]):
        w_scan[i, :, :] = tf.wt(im_scan.real[i, :, :], wavelet, mode,
                                dims, dimOpt, dimLenOpt)[0]
        w_full[i, :, :] = tf.wt(abs(im[i, :, :]), wavelet, mode,
                                dims, dimOpt, dimLenOpt)[0]
        im_dc[i, :, :] = tf.ifft2c(data[i, :, :] / np.fft.ifftshift(pdfDiv),
                                   ph=ph_scan[i, :, :]).real.copy()
        # BUGFIX: transform the current slice, not the whole 3-D volume
        # (was tf.wt(im_dc, ...), inconsistent with the sibling lines above).
        w_dc[i, :, :] = tf.wt(im_dc[i, :, :], wavelet, mode,
                              dims, dimOpt, dimLenOpt)[0]
    w_dc = w_dc.flatten()
    im_sp = im_dc.copy().reshape(N_im)
    minval = np.min(abs(im))
    maxval = np.max(abs(im))
    data = np.ascontiguousarray(data)

    # --- Four starting points for the optimizer ------------------------- #
    imdcs = [im_dc, np.zeros(N_im), np.ones(N_im),
             np.random.randn(np.prod(N_im)).reshape(N_im)]
    # Shift/scale the random start into [0, 1].
    imdcs[-1] = imdcs[-1] - np.min(imdcs[-1])
    imdcs[-1] = imdcs[-1] / np.max(abs(imdcs[-1]))
    # BUGFIX: '1/2''s' / 'nan''s' were SQL-style doubled quotes, which Python
    # concatenates to "1/2s"/"nans"; use the intended apostrophes.
    mets = ["Density Corrected", "Zeros", "1/2's", "Gaussian Random Shift (0,1)"]
    wdcs = []
    for i in range(len(imdcs)):
        wdcs.append(tf.wt(imdcs[i][0], wavelet, mode,
                          dims, dimOpt, dimLenOpt)[0].reshape(N))

    # --- Run the CS algorithm from each starting point ------------------ #
    ims = []
    for kk in range(len(wdcs)):
        w_dc = wdcs[kk]
        print(mets[kk])
        # Continuation over a decreasing TV/XFM weight schedule.
        for i in range(len(TV)):
            args = (N, N_im, dims, dimOpt, dimLenOpt, TV[i], XFM[i], data, k,
                    strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins,
                    wavelet, mode, a)
            w_result = opt.minimize(f, w_dc, args=args, method=method, jac=df,
                                    options={'maxiter': ItnLim,
                                             'lineSearchItnLim': lineSearchItnLim,
                                             'gtol': 0.01,
                                             'disp': 1,
                                             'alpha_0': alpha_0,
                                             'c': c,
                                             'xtol': xtol[i],
                                             'TVWeight': TV[i],
                                             'XFMWeight': XFM[i],
                                             'N': N})
            if np.any(np.isnan(w_result['x'])):
                print("Some nan's found. Dropping TV and XFM values")
            elif w_result['status'] != 0:
                print('TV and XFM values too high -- no solution found. Dropping...')
            else:
                w_dc = w_result['x']
        # Back to image space for this starting point's final result.
        w_res = w_dc.reshape(N)
        im_res = np.zeros(N_im)
        for i in xrange(N[0]):
            im_res[i, :, :] = tf.iwt(w_res[i, :, :], wavelet, mode,
                                     dims, dimOpt, dimLenOpt)
        ims.append(im_res)

    # --- Optional outputs ------------------------------------------------ #
    if saveNpy:
        if saveNpyFile is None:
            np.save('./holdSave_im_res_' + str(int(pctg * 100)) + 'p_all_SP', ims)
        else:
            np.save(saveNpyFile, ims)
    if saveImsPng:
        vis.figSubplots(ims, titles=mets, clims=(minval, maxval), colorbar=True)
        if not disp:
            if saveImsPngFile is None:
                saveFig.save('./holdSave_ims_' + str(int(pctg * 100)) + 'p_all_SP')
            else:
                saveFig.save(saveImsPngFile)
    if saveImDiffPng:
        imdiffs, clims = vis.imDiff(ims)
        diffMets = ['DC-Zeros', 'DC-Ones', 'DC-Random',
                    'Zeros-Ones', 'Zeros-Random', 'Ones-Random']
        vis.figSubplots(imdiffs, titles=diffMets, clims=clims, colorbar=True)
        if not disp:
            if saveImDiffPngFile is None:
                saveFig.save('./holdSave_im_diffs_' + str(int(pctg * 100)) + 'p_all_SP')
            else:
                saveFig.save(saveImDiffPngFile)
    if disp:
        plt.show()
im_dc = np.zeros(N_im) w_dc = np.zeros(N) strftime("%Y-%m-%d %H:%M:%S", localtime()) args = (N, N_im, np.prod(N_im), dims, dimOpt, dimLenOpt, TV[0], XFM[0], data, k, strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins, wavelet, mode, a) w_result = opt.fmin_tnc(f, w_dc.flat, fprime=df, args=args, accuracy=1e-4, disp=5) wHold = w_result[0].copy().reshape(N) imHold = np.zeros(N_im) for i in xrange(N[0]): imHold[i,:,:] = tf.iwt(wHold[i,:,:],wavelet,mode,dims,dimOpt,dimLenOpt) strftime("%Y-%m-%d %H:%M:%S", localtime()) ##pdfDiv = pdf.copy() ##pdfZeros = np.where(pdf==0) ##pdfDiv[pdfZeros] = 1 ##im_scan_imag = im_scan.imag ##im_scan = im_scan #x, y = np.meshgrid(np.linspace(-1,1,N[-1]),np.linspace(-1,1,N[-2])) #locs = (abs(x)<=radius) * (abs(y)<=radius)
args = (N, N_im, np.prod(N_im), dims, dimOpt, dimLenOpt, TV[0], XFM[0], data, k, strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins, wavelet, mode, a) w_result = opt.fmin_tnc(f, w_dc.flat, fprime=df, args=args, accuracy=1e-4, disp=5) wHold = w_result[0].copy().reshape(N) imHold = np.zeros(N_im) for i in xrange(N[0]): imHold[i, :, :] = tf.iwt(wHold[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt) strftime("%Y-%m-%d %H:%M:%S", localtime()) ##pdfDiv = pdf.copy() ##pdfZeros = np.where(pdf==0) ##pdfDiv[pdfZeros] = 1 ##im_scan_imag = im_scan.imag ##im_scan = im_scan #x, y = np.meshgrid(np.linspace(-1,1,N[-1]),np.linspace(-1,1,N[-2])) #locs = (abs(x)<=radius) * (abs(y)<=radius) #minLoc = np.min(np.where(locs==True)) #pctgSamp = np.zeros(minLoc+1) #for i in range(1,minLoc+1):
#stps.append(w_dc) #w_stp.append(w_dc.reshape(NSub)) #im_hld = np.zeros(N_imSub) #for i in range(NSub[0]): #im_hld[i] = tf.iwt(w_stp[-1][i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub) #imStp.append(im_hld) #plt.imshow(imStp[-1],clim=(minval,maxval)); plt.colorbar(); plt.show() #w_dc = w_stp[k].flatten() #stps.append(w_dc) #wdcHold = w_dc.reshape(NSub) #dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub),axes=(-2,-1)) #kStp = np.fft.fftshift(kSub,axes=(-2,-1)).copy() #kMaskRpt = kMasked.reshape(np.hstack([1,N_imSub[-2:]])).repeat(N_imSub[0],0) im_hld = np.zeros(N_imSub) for i in range(NSub[0]): im_hld[i] = tf.iwt(w_dc[i], wavelet, mode, dimsSub, dimOptSub, dimLenOptSub) data_dc = tf.fft2c(im_hld, ph=ph_scanSub, axes=(-2, -1)) kMasked = (np.floor(1 - kMasked) * pctgSamp[locSteps[j]] * lam_trust + kMasked) #kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*1.0 + kMasked).reshape(np.hstack([1, N_imSub[-2:]])) wHold = w_dc.copy().reshape(NSub) imHold = np.zeros(N_imSub) for i in xrange(N[0]): imHold[i, :, :] = tf.iwt(wHold[i, :, :], wavelet, mode, dimsSub, dimOptSub, dimLenOptSub) np.save( 'temp/' + str(int(100 * pctg)) + '_slice_' + str(sliceChoice) + '_TV_' + str(TV) + '_XFM_' + str(XFM) + '_lamTrust_' +
'N': N }) if np.any(np.isnan(w_result['x'])): print('Some nan' 's found. Dropping TV and XFM values') elif w_result['status'] != 0: print( 'TV and XFM values too high -- no solution found. Dropping...') else: w_dc = w_result['x'] stps.append(w_dc) tvStps.append(TV[i]) w_res = w_dc.reshape(N) im_res = np.zeros(N_im) for i in xrange(N[0]): im_res[i, :, :] = tf.iwt(w_res[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt) ims.append(im_res) im_stps = np.zeros([len(stps), N_im[-2], N_im[-1]]) gtv = np.zeros([len(stps), N_im[-2], N_im[-1]]) gxfm = np.zeros([len(stps), N_im[-2], N_im[-1]]) gdc = np.zeros([len(stps), N_im[-2], N_im[-1]]) for jj in range(len(stps)): im_stps[jj, :, :] = tf.iwt(stps[jj].reshape(N[-2:]), wavelet, mode, dims, dimOpt, dimLenOpt) gtv[jj, :, :] = grads.gTV(im_stps[jj, :, :].reshape(N_im), N_im, strtag, kern, 0, a=a)
#stps.append(w_dc) #wdcHold = w_dc.reshape(NSub) #kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked #if j == len(TV): #print('No solution found on final run. Saving last spot.') #w_dc = w_result['x'] ##import pdb; pdb.set_trace() #stps.append(w_dc) #tvStps.append(TV[i]) #w_stp.append(w_dc.reshape(NSub)) #imStp.append(tf.iwt(w_stp[-1][0],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)) ##plt.imshow(imStp[-1]); plt.colorbar(); plt.show() ##w_dc = w_stp[k].flatten() #stps.append(w_dc) #wdcHold = w_dc.reshape(NSub) ##dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub)) ##kStp = np.fft.fftshift(kSub).copy() #kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked w_dc = w_result['x'] #import pdb; pdb.set_trace() stps.append(w_dc) tvStps.append(TV[i]) w_stp.append(w_dc.reshape(NSub)) imStp.append(tf.iwt(w_stp[-1][0],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub)) #plt.imshow(imStp[-1]); plt.colorbar(); plt.show() #w_dc = w_stp[k].flatten() stps.append(w_dc) wdcHold = w_dc.reshape(NSub) kMasked = np.floor(1-kMasked)*pctgSamp[locSteps[j]] + kMasked dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub)) kStp = np.fft.fftshift(kSub).copy()
def runCSAlgorithm(fromfid=False,
                   filename='/home/asalerno/Documents/pyDirectionCompSense/brainData/P14/data/fullySampledBrain.npy',
                   sliceChoice=150,
                   strtag=['', 'spatial', 'spatial'],
                   xtol=[1e-2, 1e-3, 5e-4, 5e-4],
                   TV=[0.01, 0.005, 0.002, 0.001],
                   XFM=[0.01, .005, 0.002, 0.001],
                   dirWeight=0,
                   pctg=0.25,
                   radius=0.2,
                   P=2,
                   pft=False,
                   ext=0.5,
                   wavelet='db4',
                   mode='per',
                   method='CG',
                   ItnLim=30,
                   lineSearchItnLim=30,
                   alpha_0=0.6,
                   c=0.6,
                   a=10.0,
                   kern=np.array([[[0., 0., 0.],
                                   [0., 0., 0.],
                                   [0., 0., 0.]],
                                  [[0., 0., 0.],
                                   [0., -1., 0.],
                                   [0., 1., 0.]],
                                  [[0., 0., 0.],
                                   [0., -1., 1.],
                                   [0., 0., 0.]]]),
                   dirFile=None,
                   nmins=None,
                   dirs=None,
                   M=None,
                   dirInfo=[None] * 4,
                   saveNpy=False,
                   saveNpyFile=None,
                   saveImsPng=False,
                   saveImsPngFile=None,
                   saveImDiffPng=False,
                   saveImDiffPngFile=None,
                   disp=False):
    '''
    End-to-end CS reconstruction driver for a single slice.

    Loads data (FID pair or .npy volume), builds a variable-density sampling
    pattern, simulates undersampled "scanner" data with a smoothed phase
    estimate, then runs the CS optimization (`opt.minimize` with `f`/`df`)
    from four starting points (density-corrected, zeros, ones, shifted
    Gaussian random) over a decreasing schedule of TV/XFM weights.
    Optionally saves/plots the reconstructions and their pairwise diffs.

    NOTE(review): the list/array default arguments (strtag, xtol, TV, XFM,
    kern, dirInfo) are shared across calls; they are only read here, but
    callers must not mutate them.
    '''
    # --- Load the fully sampled image ---------------------------------- #
    if fromfid == True:
        # `filename` is (petable, inputdirectory) in this mode.
        inputdirectory = filename[0]
        petable = filename[1]
        fullImData = rff.getDataFromFID(petable, inputdirectory, 2)[0, :, :, :]
        fullImData = fullImData / np.max(abs(fullImData))
        im = fullImData[:, :, sliceChoice]
    else:
        im = np.load(filename)[sliceChoice, :, :]
    N = np.array(im.shape)  # image Size

    # --- Build the sampling PDF and a realized sampling mask ------------ #
    pdf = samp.genPDF(N[-2:], P, pctg, radius=radius,
                      cyl=np.hstack([1, N[-2:]]), style='mult', pft=pft, ext=ext)
    if pft:
        print('Partial Fourier sampling method used')
    k = samp.genSampling(pdf, 50, 2)[0].astype(int)
    # Normalize everything to a 3-D (slices, ny, nx) layout.
    if len(N) == 2:
        N = np.hstack([1, N])
        k = k.reshape(N)
        im = im.reshape(N)
    elif (len(N) == 3) and ('dir' not in strtag):
        k = k.reshape(np.hstack([1, N[-2:]])).repeat(N[0], 0)

    # --- Simulate scanner data and estimate a smooth phase per slice ---- #
    ph_ones = np.ones(N[-2:], complex)
    ph_scan = np.zeros(N, complex)
    data = np.zeros(N, complex)
    im_scan = np.zeros(N, complex)
    for i in range(N[0]):
        k[i, :, :] = np.fft.fftshift(k[i, :, :])
        data[i, :, :] = k[i, :, :] * tf.fft2c(im[i, :, :], ph=ph_ones)
        # IMAGE from the "scanner data"
        im_scan_wph = tf.ifft2c(data[i, :, :], ph=ph_ones)
        ph_scan[i, :, :] = tf.matlab_style_gauss2D(im_scan_wph, shape=(5, 5))
        ph_scan[i, :, :] = np.exp(1j * ph_scan[i, :, :])
        im_scan[i, :, :] = tf.ifft2c(data[i, :, :], ph=ph_scan[i, :, :])
    #im_lr = samp.loRes(im,pctg)

    # ------------------------------------------------------------------ #
    # A quick way to look at the PSF of the sampling pattern that we use  #
    delta = np.zeros(N[-2:])
    delta[int(N[-2] / 2), int(N[-1] / 2)] = 1
    psf = tf.ifft2c(tf.fft2c(delta, ph_ones) * k, ph_ones)
    # ------------------------------------------------------------------ #

    # Density correction: divide sampled k-space by the PDF (guarding the
    # zero entries) to de-bias the zero-filled reconstruction.
    pdfDiv = pdf.copy()
    pdfZeros = np.where(pdf == 0)
    pdfDiv[pdfZeros] = 1

    # --- Move into the wavelet domain ----------------------------------- #
    N_im = N.copy()
    hld, dims, dimOpt, dimLenOpt = tf.wt(im_scan[0].real, wavelet, mode)
    N = np.hstack([N_im[0], hld.shape])
    w_scan = np.zeros(N)
    w_full = np.zeros(N)
    im_dc = np.zeros(N_im)
    w_dc = np.zeros(N)
    for i in xrange(N[0]):
        w_scan[i, :, :] = tf.wt(im_scan.real[i, :, :], wavelet, mode,
                                dims, dimOpt, dimLenOpt)[0]
        w_full[i, :, :] = tf.wt(abs(im[i, :, :]), wavelet, mode,
                                dims, dimOpt, dimLenOpt)[0]
        im_dc[i, :, :] = tf.ifft2c(data[i, :, :] / np.fft.ifftshift(pdfDiv),
                                   ph=ph_scan[i, :, :]).real.copy()
        # BUGFIX: transform the current slice, not the whole 3-D volume
        # (was tf.wt(im_dc, ...), inconsistent with the sibling lines above).
        w_dc[i, :, :] = tf.wt(im_dc[i, :, :], wavelet, mode,
                              dims, dimOpt, dimLenOpt)[0]
    w_dc = w_dc.flatten()
    im_sp = im_dc.copy().reshape(N_im)
    minval = np.min(abs(im))
    maxval = np.max(abs(im))
    data = np.ascontiguousarray(data)

    # --- Four starting points for the optimizer ------------------------- #
    imdcs = [im_dc, np.zeros(N_im), np.ones(N_im),
             np.random.randn(np.prod(N_im)).reshape(N_im)]
    # Shift/scale the random start into [0, 1].
    imdcs[-1] = imdcs[-1] - np.min(imdcs[-1])
    imdcs[-1] = imdcs[-1] / np.max(abs(imdcs[-1]))
    # BUGFIX: '1/2' 's' / 'nan' 's' were SQL-style doubled quotes, which
    # Python concatenates to "1/2s"/"nans"; use the intended apostrophes.
    mets = ["Density Corrected", "Zeros", "1/2's", "Gaussian Random Shift (0,1)"]
    wdcs = []
    for i in range(len(imdcs)):
        wdcs.append(tf.wt(imdcs[i][0], wavelet, mode,
                          dims, dimOpt, dimLenOpt)[0].reshape(N))

    # --- Run the CS algorithm from each starting point ------------------ #
    ims = []
    for kk in range(len(wdcs)):
        w_dc = wdcs[kk]
        print(mets[kk])
        # Continuation over a decreasing TV/XFM weight schedule.
        for i in range(len(TV)):
            args = (N, N_im, dims, dimOpt, dimLenOpt, TV[i], XFM[i], data, k,
                    strtag, ph_scan, kern, dirWeight, dirs, dirInfo, nmins,
                    wavelet, mode, a)
            w_result = opt.minimize(f, w_dc, args=args, method=method, jac=df,
                                    options={'maxiter': ItnLim,
                                             'lineSearchItnLim': lineSearchItnLim,
                                             'gtol': 0.01,
                                             'disp': 1,
                                             'alpha_0': alpha_0,
                                             'c': c,
                                             'xtol': xtol[i],
                                             'TVWeight': TV[i],
                                             'XFMWeight': XFM[i],
                                             'N': N})
            if np.any(np.isnan(w_result['x'])):
                print("Some nan's found. Dropping TV and XFM values")
            elif w_result['status'] != 0:
                print('TV and XFM values too high -- no solution found. Dropping...')
            else:
                w_dc = w_result['x']
        # Back to image space for this starting point's final result.
        w_res = w_dc.reshape(N)
        im_res = np.zeros(N_im)
        for i in xrange(N[0]):
            im_res[i, :, :] = tf.iwt(w_res[i, :, :], wavelet, mode,
                                     dims, dimOpt, dimLenOpt)
        ims.append(im_res)

    # --- Optional outputs ------------------------------------------------ #
    if saveNpy:
        if saveNpyFile is None:
            np.save('./holdSave_im_res_' + str(int(pctg * 100)) + 'p_all_SP', ims)
        else:
            np.save(saveNpyFile, ims)
    if saveImsPng:
        vis.figSubplots(ims, titles=mets, clims=(minval, maxval), colorbar=True)
        if not disp:
            if saveImsPngFile is None:
                saveFig.save('./holdSave_ims_' + str(int(pctg * 100)) + 'p_all_SP')
            else:
                saveFig.save(saveImsPngFile)
    if saveImDiffPng:
        imdiffs, clims = vis.imDiff(ims)
        diffMets = ['DC-Zeros', 'DC-Ones', 'DC-Random',
                    'Zeros-Ones', 'Zeros-Random', 'Ones-Random']
        vis.figSubplots(imdiffs, titles=diffMets, clims=clims, colorbar=True)
        if not disp:
            if saveImDiffPngFile is None:
                saveFig.save('./holdSave_im_diffs_' + str(int(pctg * 100)) + 'p_all_SP')
            else:
                saveFig.save(saveImDiffPngFile)
    if disp:
        plt.show()
#stps.append(w_dc) #w_stp.append(w_dc.reshape(NSub)) #im_hld = np.zeros(N_imSub) #for i in range(NSub[0]): #im_hld[i] = tf.iwt(w_stp[-1][i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub) #imStp.append(im_hld) #plt.imshow(imStp[-1],clim=(minval,maxval)); plt.colorbar(); plt.show() #w_dc = w_stp[k].flatten() #stps.append(w_dc) #wdcHold = w_dc.reshape(NSub) #dataStp = np.fft.fftshift(tf.fft2c(imStp[-1],ph_scanSub),axes=(-2,-1)) #kStp = np.fft.fftshift(kSub,axes=(-2,-1)).copy() #kMaskRpt = kMasked.reshape(np.hstack([1,N_imSub[-2:]])).repeat(N_imSub[0],0) im_hld = np.zeros(N_imSub) for i in range(NSub[0]): im_hld[i] = tf.iwt(w_dc[i],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub) data_dc = tf.fft2c(im_hld, ph=ph_scanSub, axes=(-2,-1)) kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*lam_trust + kMasked) #kMasked = (np.floor(1-kMasked)*pctgSamp[locSteps[j]]*1.0 + kMasked).reshape(np.hstack([1, N_imSub[-2:]])) wHold = w_dc.copy().reshape(NSub) imHold = np.zeros(N_imSub) for i in xrange(N[0]): imHold[i,:,:] = tf.iwt(wHold[i,:,:],wavelet,mode,dimsSub,dimOptSub,dimLenOptSub) np.save('/hpf/largeprojects/MICe/asalerno/pyDirectionCompSense/tests/fullBrainTests/' + str(int(100*pctg)) + '_3_spatial_TV_im_final' + str(int(nSteps)) + '_comb_ks.npy',imHold) np.save('/hpf/largeprojects/MICe/asalerno/pyDirectionCompSense/tests/fullBrainTests/' + str(int(100*pctg)) + '_3_spatial_TV_im_final' + str(int(nSteps)) + '_comb_ks.npy',wHold) #outvol = volumeFromData('/hpf/largeprojects/MICe/asalerno/pyDirectionCompSense/tests/fullBrainTests/' + str(int(100*pctg)) + '_3_spatial_TV_im_final_' + str(int(nSteps)) + '_comb_ks.mnc', imHold, dimnames=['xspace','yspace','zspace'], starts=(0, 0, 0), steps=(1, 1, 1), volumeType="uint")
def objectiveFunction(x, N, N_im, sz, dims, dimOpt, dimLenOpt, lam1, lam2,
                      data, k, strtag, ph, kern, dirWeight=0, dirs=None,
                      dirInfo=None, nmins=0, wavelet='db4', mode="per", a=10.):
    '''
    Objective (cost) function for the CS optimization (variant that threads
    a `sz` argument and `strtag` through to the data-consistency term).

    Evaluates, for the wavelet-domain iterate x:
        Re(|data-consistency|) + lam1 * |TV| + lam2 * |XFM|
    and returns the absolute value of the weighted sum (a scalar).  The TV
    and XFM terms are only computed when their weights exceed 1e-6.

    Side effects: temporarily reshapes `x` and `data` in place (restored to
    flat vectors before returning).
    '''
    # dirInfo[0] is M.  None sentinel instead of a shared mutable default.
    if dirInfo is None:
        dirInfo = [None, None, None, None]
    tv = 0
    xfm = 0
    data.shape = N_im
    x.shape = N
    if len(N) > 2:
        # Complex image buffer here (unlike the other variants) because the
        # data-consistency term takes the real part explicitly below.
        x0 = np.zeros(N_im, complex)
        for i in xrange(N[0]):
            x0[i, :, :] = tf.iwt(x[i, :, :], wavelet, mode, dims, dimOpt, dimLenOpt)
    else:
        x0 = tf.iwt(x, wavelet, mode, dims, dimOpt, dimLenOpt)
    obj = np.sum(objectiveFunctionDataCons(x0, N_im, ph, data, k, sz, strtag)).real
    if lam1 > 1e-6:
        tv = np.sum(abs(objectiveFunctionTV(x0, N_im, strtag, kern, dirWeight,
                                            dirs, nmins, dirInfo=dirInfo, a=a)))
    if lam2 > 1e-6:
        # Smooth |.| surrogate; 1./a guards against integer `a` truncating
        # to zero under Python 2 division.
        xfm = np.sum(abs((1. / a) * np.log(np.cosh(a * x))))
    # Not the most efficient way to do this, but we need the shape to reset.
    x.shape = (x.size,)
    data.shape = (data.size,)
    return abs(obj + lam1 * tv + lam2 * xfm)