import numpy as np
from functions import deconvolve  # OASIS (https://github.com/j-friedrich/OASIS); 'OASIS/' must be on sys.path
import event_detection as ed  # assumed module name for Peter's event-detection script


def ca_deconvolution(ddf_trace, l0=False):
    """Perform calcium imaging deconvolution.

    This function runs several calcium imaging deconvolution approaches.

    Deconvolutions currently supported:
    1. OASIS (https://github.com/j-friedrich/OASIS)
    2. Event detection script from Peter
    3. AR-FPOP (only if l0=True)

    input:
        ddf_trace: a 1d numpy array of length n (the number of time steps
            in the calcium trace)
        l0: if True, also run the AR-FPOP (L0) deconvolution

    output:
        a dictionary whose keys are the deconvolution methods used and whose
        values are 1d numpy arrays of length n with the estimated spikes

    TODO: Add functionality for the following methods
    4. ML Spike
    5. One of the supervised methods?
    """
    out = {}

    # Method 1: OASIS (https://github.com/j-friedrich/OASIS)
    c, s, b, g, lam = deconvolve(np.double(ddf_trace), penalty=1)
    out['OASIS'] = s

    # Method 2: event detection
    yes_array, size_array = ed.get_events_derivative(ddf_trace)
    times_new, heights_new = ed.concatenate_adjacent_events(yes_array, size_array, delta=3)
    tmp = np.zeros_like(ddf_trace)
    tmp[times_new] = heights_new
    out['event_detection'] = tmp

    # Method 3: AR-FPOP (FastLZeroSpikeInference)
    if l0:
        import arfpop_ctypes as af
        # Some default parameters that need to be tuned!
        gam = 0.99
        penalty = 0.25
        constraint = False
        ar_fit = af.arfpop(ddf_trace, gam, penalty, constraint)
        out['arfpop'] = ar_fit['pos_spike_mag']

    return out
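
# Minimal usage sketch (not part of the original source): build a synthetic
# AR(1) calcium trace and run the default deconvolution methods. Assumes the
# imports above resolve (OASIS `deconvolve` and the event-detection module
# `ed`); the trace statistics below are illustrative only.
def _example_ca_deconvolution():
    n, gam = 1000, 0.95
    rng = np.random.RandomState(0)
    true_spikes = (rng.rand(n) < 0.01).astype(float)  # sparse binary spike train
    trace = np.zeros(n)
    for t in range(1, n):  # AR(1) calcium dynamics driven by the spikes
        trace[t] = gam * trace[t - 1] + true_spikes[t]
    trace += 0.1 * rng.randn(n)  # additive measurement noise
    estimates = ca_deconvolution(trace)  # dict: method name -> length-n spike estimate
    for method, spikes in estimates.items():
        print(method, '->', int(np.sum(spikes > 0)), 'frames with detected events')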
def PlotAll(SaveNames, params):
    from numpy import min, max, percentile, asarray, ceil, sqrt
    import numpy as np
    import sys
    from scipy.signal import welch
    from pylab import load
    from past.utils import old_div  # old_div is used throughout below
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib.backends.backend_pdf import PdfPages
#    from scipy.ndimage.measurements import label
    from AuxilaryFunctions import GetRandColors, max_intensity, SuperVoxelize, GetData, PruneComponents, SplitComponents, ThresholdShapes, MergeComponents, ThresholdData, make_sure_path_exists
#    from BlockLocalNMF_AuxilaryFunctions import HALS4activity
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # make sure relevant folders exist, and add OASIS to the path
    Results_folder = 'Results/'
    make_sure_path_exists(Results_folder)
    OASIS_path = 'OASIS/'
    make_sure_path_exists(OASIS_path)
    sys.path.append(OASIS_path)
    from functions import deconvolve

    ## plotting params
    # what to plot
    plot_activities = True
    plot_activities_PSD = False
    plot_shapes_projections = True
    plot_shapes_slices = False
    plot_activityCorrs = False
    plot_clustered_shape = False
    plot_residual_slices = False
    plot_residual_projections = False
    # videos to generate
    video_shapes = False
    video_residual = True
    video_slices = False
    # what to save
    save_video = True
    save_plot = True
    close_figs = True  # close all figs right after saving (to avoid memory overload)
    # post-processing
    Split = False
    Threshold = False  # threshold shapes in the end and keep only connected components
    Prune = False  # remove "bad" components (where "bad" is defined within the PruneComponents function)
    Merge = True  # merge highly correlated nearby components
    FineTune = False  # should we fine-tune activity after post-processing (mainly after merging)?
    IncludeBackground = False  # should we include the background as an extracted component?
    # how to plot
    detrend = True  # should we detrend the data (remove background component)?
    scale = 2  # scale colormap to enhance colors
    saturation_percentile = 96  # saturate colormap at this percentile, when ma=percentile is used
    dpi = 200  # for videos
    restrict_support = True  # in shape video, zero out data outside support of shapes
    C = 4  # number of components to show in shape videos (if larger than the number of shapes L, we automatically set C=L)
    color_map = 'gray'  # 'gnuplot'
    frame_rate = 10.0  # Hz

    # Fetch experimental 3D data
    data = GetData(params.data_name)
    if params.SuperVoxelize == True:
        data = SuperVoxelize(data)
    if params.ThresholdData == True:
        data = ThresholdData(data)

    dims = np.shape(data)
    if len(dims) < 4:
        plot_Shapes2D = False
        video_residual_2D = False
        if plot_shapes_projections or plot_shapes_slices:
            plot_Shapes2D = True
        if video_residual == True:
            video_residual_2D = True
        plot_shapes_projections = False
        plot_shapes_slices = False
        plot_activityCorrs = False
        plot_clustered_shape = False
        plot_residual_slices = False
        plot_residual_projections = False
        video_shapes = False
        video_residual = False
        video_slices = False
        print('2D data, ignoring 3D plots/video options')
    else:
        plot_Shapes2D = False
        video_residual_2D = False
    min_dim = np.argmin(dims[1:])

    denoised_data = 0
    detrended_data = data
    for rep in range(len(SaveNames)):
        resultsName = SaveNames[rep]
        try:
            results = load('NMF_Results/' + SaveNames[rep])
        except IOError:
            if rep == 0:
                print('results file not found!!')
            else:
                break
        SS = results['shapes']
        AA = results['activity']
        if rep >= params.Background_num:
            adaptBias = False
        else:
            adaptBias = True
        if IncludeBackground == True:
            adaptBias = False
        L = len(AA) - adaptBias
        if L == 0:  # stop if we encounter a file with zero components
            break
        S = SS[:L]  # note: SS[:-adaptBias] fails when adaptBias==0, so slice with L instead
        b = SS[L:(L + adaptBias)]
        A = AA[:L]
        f = AA[L:(L + adaptBias)]
        if rep == 0:
            shapes = S
            activity = A
            background_shapes = b
            background_activity = f
        else:
            shapes = np.append(shapes, S, axis=0)
            activity = np.append(activity, A, axis=0)
            background_shapes = np.append(background_shapes, b, axis=0)
            background_activity = np.append(background_activity, f, axis=0)

    L = len(shapes)
    adaptBias = 0
    if Split == True:
        shapes, activity, L, all_local_max = SplitComponents(shapes, activity, adaptBias)
    if Merge == True:
        shapes, activity, L = MergeComponents(shapes, activity, L, threshold=0.7, sig=10)
    if Prune == True:
#        deleted_indices=[5,9,11,14,15,17,24]+range(25,36)
        shapes, activity, L = PruneComponents(shapes, activity, L, params.TargetAreaRatio)

    activity_NonNegative = np.copy(activity)
    activity_NonNegative[activity_NonNegative < 0] = 0
    activity_noisy = np.copy(activity_NonNegative)
    if FineTune == True:
        for ll in range(L):
            activity[ll], spikes, baseline, g, lam = deconvolve(activity_NonNegative[ll], optimize_g=10, penalty=0)
#        activity,background_activity,S,bl,c1,sn,g,junk = update_temporal_components(data.reshape((len(data),-1)).transpose(), shapes.reshape((len(shapes),-1)).transpose(), background_shapes.reshape((len(background_shapes),-1)).transpose(), activity,background_activity,**options['temporal_params'])
        activity_noisy = np.copy(activity_NonNegative)
        activity_NonNegative = activity
    print(str(L) + ' shapes detected')

    detrended_data = detrended_data - background_activity.T.dot(background_shapes.reshape((len(background_shapes), -1))).reshape(dims)
    if Threshold == True:
        shapes = ThresholdShapes(shapes, adaptBias, [], MaxRatio=[])
    if plot_residual_projections or video_shapes or video_residual or video_slices or video_residual_2D:
        colors = GetRandColors(L)
        color_shapes = np.transpose(shapes.reshape(L, -1, 1) * colors, [1, 0, 2])  # weird transpose for the tensor dot product on the next line
        denoised_data = denoised_data + (activity_NonNegative.T.dot(color_shapes)).reshape(tuple(dims) + (3,))
        residual = detrended_data - activity_NonNegative.T.dot(shapes.reshape(L, -1)).reshape(dims)

    if detrend == True:
        data = detrended_data

    #%% After loading loop - normalize (colored) denoised data
    if plot_residual_projections or video_shapes or video_residual or video_slices or video_residual_2D:
#        denoised_data=denoised_data/np.max(denoised_data)
        denoised_data = old_div(denoised_data, np.percentile(denoised_data[denoised_data > 0], 99.5))  # normalize denoised data range
        denoised_data[denoised_data > 1] = 1

#    plt.close('all')
    #%% plotting params
    ComponentsInFig = 20
    left = 0.05    # the left side of the subplots of the figure
    right = 0.95   # the right side of the subplots of the figure
    bottom = 0.05  # the bottom of the subplots of the figure
    top = 0.95     # the top of the subplots of the figure
    wspace = 0.1   # the amount of width reserved for blank space between subplots
    hspace = 0.12  # the amount of height reserved for white space between subplots

    #%% ###### Plot individual neurons' activities
    index = 0  # component display index
    sz = np.min([ComponentsInFig, L + adaptBias])
#    a=ceil(sqrt(sz))
#    b=ceil(sz/a)
    a = sz
    b = 1
    if plot_activities:
        pp = PdfPages(Results_folder + 'Activities' + resultsName + '.pdf')
        for ii in range(L + adaptBias):
            if index == 0:
#                fig0=plt.figure(figsize=(dims[1] , dims[2]))
                fig0 = plt.figure(figsize=(11, 18))
            ax = plt.subplot(a, b, index + 1)
#            dt=1/30 # 30 Hz sample rate
            time = list(range(len(activity[ii])))
            plt.plot(time, activity_noisy[ii], linewidth=0.5, c='r')
            plt.plot(time, activity[ii], linewidth=3, c='b')
            ma = np.max([np.max(activity[ii]), np.max(activity_noisy[ii])])
            plt.setp(ax, xticks=[], yticks=[0, ma])
            # component number
            ax.text(0.02, 0.8, str(ii),
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='black', weight='bold', fontsize=13)
            index += 1
            if ((ii % ComponentsInFig) == (ComponentsInFig - 1)) or ii == (L + adaptBias - 1):
                index = 0
                if save_plot == True:
                    plt.subplots_adjust(left * 2, bottom, right, top, wspace, hspace * 2)
                    pp.savefig(fig0)
        pp.close()
        if close_figs:
            plt.close('all')

    #%% Plot activities' PSDs
    index = 0  # component display index
    sz = np.min([ComponentsInFig, L + adaptBias])
    a = ceil(sqrt(sz))
    b = ceil(old_div(sz, a))
    if plot_activities_PSD:
        pp = PdfPages(Results_folder + 'ActivityPSDs' + resultsName + '.pdf')
        for ii in range(L + adaptBias):
            if index == 0:
#                fig0=plt.figure(figsize=(dims[1] , dims[2]))
                fig0 = plt.figure(figsize=(11, 18))
            ax = plt.subplot(a, b, index + 1)
            ff, psd_activity = welch(activity[ii], nperseg=round(old_div(len(activity[ii]), 64)))
            plt.plot(ff, psd_activity, linewidth=3)
            plt.setp(ax, xticks=[], yticks=[0])
            # component number
            ax.text(0.02, 0.8, str(ii),
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='black', weight='bold', fontsize=13)
            index += 1
            if ((ii % ComponentsInFig) == (ComponentsInFig - 1)) or ii == (L + adaptBias - 1):
                index = 0
                if save_plot == True:
                    plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
                    pp.savefig(fig0)
        pp.close()
        if close_figs:
            plt.close('all')

    #%% 2D shapes
    index = 0  # component display index
    sz = np.min([ComponentsInFig, L + adaptBias])
    a = ceil(0.5 * sqrt(sz))
    b = ceil(old_div(sz, a))
    if plot_Shapes2D:
        if save_plot == True:
            pp = PdfPages(Results_folder + 'Shapes2D_' + resultsName + '.pdf')
        for ll in range(L + adaptBias):
            if index == 0:
                fig = plt.figure(figsize=(18, 11))
            ax = plt.subplot(a, b, index + 1)
            temp = shapes[ll]
            mi = 0
            try:
                ma = np.percentile(temp[temp > 0], saturation_percentile)
            except IndexError:
                ma = 0
            im = plt.imshow(temp, vmin=mi, vmax=ma, cmap=color_map)
            plt.setp(ax, xticks=[], yticks=[])
            mn = int(np.floor(mi))  # colorbar min value
            mx = int(np.ceil(ma))   # colorbar max value
            md = old_div((mx - mn), 2)
#            divider = make_axes_locatable(ax)
#            cax = divider.append_axes("right", size="5%", pad=0.05)
#            cb=plt.colorbar(im,cax=cax)
#            cb.set_ticks([mn,md,mx])
#            cb.set_ticklabels([mn,md,mx])
            # component number
            ax.text(0.02, 0.8, str(ll),
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='white', weight='bold', fontsize=13)
            # sparsity
            spar_str = str(np.round(np.mean(shapes[ll] > 0) * 100, 2)) + '%'
            ax.text(0.02, 0.02, spar_str,
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='white', weight='bold', fontsize=13)
            # L^p
            for p in range(2, 6, 2):
                Lp = old_div((np.sum(shapes[ll] ** p)) ** (old_div(1, float(p))), np.sum(shapes[ll]))
                Lp_str = str(np.round(Lp * 100, 2)) + '%'  # 'L'+str(p)+'='+
                ax.text(0.02 + p * 0.2, 0.02, Lp_str,
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color='yellow', weight='bold', fontsize=13)
            index += 1
            if (ll % ComponentsInFig == (ComponentsInFig - 1)) or ll == L + adaptBias - 1:
                plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
                index = 0
                if save_plot == True:
                    pp.savefig(fig)
        pp.close()
        if close_figs:
            plt.close('all')

    #%% Re-write plot code from here, so that each figure has only ComponentsInFig components
    #%% ###### Plot individual neurons' area which is correlated with their activities
    a = ceil(sqrt(L + adaptBias))
    b = ceil(old_div((L + adaptBias), a))
    if plot_activityCorrs:
        if save_plot == True:
            pp = PdfPages(Results_folder + 'CorrelationWithActivity' + resultsName + '.pdf')
        for dd in range(len(shapes[0].shape)):
            fig0 = plt.figure(figsize=(11, 18))
            for ii in range(L + adaptBias):
                ax = plt.subplot(a, b, ii + 1)
                corr_imag = old_div(np.dot(activity[ii], np.transpose(data, [1, 2, 0, 3])),
                                    np.sqrt(np.sum(data ** 2, axis=0) * np.sum(activity[ii] ** 2)))
                plt.imshow(np.abs(corr_imag).max(dd), cmap=color_map)
                plt.setp(ax, xticks=[], yticks=[])
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            if save_plot == True:
                pp.savefig(fig0)
        pp.close()
        if close_figs:
            plt.close('all')

    #%% All shapes' projections
    a = ceil(sqrt(L + adaptBias))
    b = ceil(old_div((L + adaptBias), a))
    if plot_shapes_projections:
        if save_plot == True:
            pp = PdfPages(Results_folder + 'Shapes_projections' + resultsName + '.pdf')
        for dd in range(len(shapes[0].shape)):
            fig = plt.figure(figsize=(18, 11))
            for ll in range(L + adaptBias):
                ax = plt.subplot(a, b, ll + 1)
                temp = shapes[ll].max(dd)
                if dd == 2:
                    temp = temp.T
                mi = np.min(shapes[ll])
                ma = np.max(shapes[ll])
                im = plt.imshow(temp, vmin=mi, vmax=ma, cmap=color_map)
                plt.setp(ax, xticks=[], yticks=[])
                mn = int(np.floor(mi))  # colorbar min value
                mx = int(np.ceil(ma))   # colorbar max value
                md = old_div((mx - mn), 2)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                cb = plt.colorbar(im, cax=cax)
#                cb.set_ticks([mn,md,mx])
#                cb.set_ticklabels([mn,md,mx])
                # component number
                ax.text(0.02, 0.8, str(ll),
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color='white', weight='bold', fontsize=13)
                # sparsity
                spar_str = str(np.round(np.mean(shapes[ll] > 0) * 100, 2)) + '%'
                ax.text(0.02, 0.02, spar_str,
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color='white', weight='bold', fontsize=13)
#                # L^p
#                for p in range(2,2,2):
#                    Lp=(np.sum(shapes[ll]**p))**(1/float(p))/np.sum(shapes[ll])
#                    Lp_str=str(np.round(Lp*100,2))+'%' #'L'+str(p)+'='+
#                    ax.text(0.02+p*0.2, 0.02, Lp_str,
#                            verticalalignment='bottom', horizontalalignment='left',
#                            transform=ax.transAxes,
#                            color='yellow',weight='bold', fontsize=13)
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            if save_plot == True:
                pp.savefig(fig)
        pp.close()
        if close_figs:
            plt.close('all')
#    for ll in range(L+adaptBias):
#        print('Sparsity=', np.mean(shapes[ll]>0))

    #%% All shapes' slices
    transpose_shape = True  # should we transpose the shape?
    ComponentsInFig = 3  # number of components per figure
    index = 0  # component display index
#    z_slices=[0,1,2,3,4,5,6,7,8] #which z slices to look at in slice plots/videos
    z_slices = list(range(dims[min_dim + 1]))  # which z slices to look at in slice plots/videos
    if plot_shapes_slices:
        if save_plot == True:
            pp = PdfPages(Results_folder + 'Shapes_slices' + resultsName + '.pdf')
        for ll in range(L + adaptBias):
            if index == 0:
                fig = plt.figure(figsize=(18, 11))
            for dd in range(len(z_slices)):
                ax = plt.subplot(ComponentsInFig, len(z_slices), index * len(z_slices) + dd + 1)
                temp = shapes[ll].take(dd, axis=min_dim)
                if transpose_shape:
                    temp = np.transpose(temp)
                mi = np.min(shapes[ll])
                ma = np.max(shapes[ll])
                im = plt.imshow(temp, vmin=mi, vmax=ma, cmap=color_map)
                plt.setp(ax, xticks=[], yticks=[])
                if dd == 0:
                    # component number
                    ax.text(0.02, 0.8, str(ll),
                            verticalalignment='bottom', horizontalalignment='left',
                            transform=ax.transAxes,
                            color='white', weight='bold', fontsize=13)
                    # sparsity
                    spar_str = str(np.round(np.mean(shapes[ll] > 0) * 100, 2)) + '%'
                    ax.text(0.02, 0.02, spar_str,
                            verticalalignment='bottom', horizontalalignment='left',
                            transform=ax.transAxes,
                            color='white', weight='bold', fontsize=13)
                mn = int(np.floor(mi))  # colorbar min value
                mx = int(np.ceil(ma))   # colorbar max value
                md = old_div((mx - mn), 2)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("bottom", size="5%", pad=0.05)
                cb = plt.colorbar(im, cax=cax, orientation="horizontal")
                cb.set_ticks([mn, md, mx])
                cb.set_ticklabels([mn, md, mx])
                # L^p (note: range(2,2,2) is empty, so this loop is currently disabled)
                for p in range(2, 2, 2):
                    Lp = old_div((np.sum(shapes[ll] ** p)) ** (old_div(1, float(p))), np.sum(shapes[ll]))
                    Lp_str = str(np.round(Lp * 100, 2)) + '%'  # 'L'+str(p)+'='+
                    ax.text(0.02 + p * 0.15, 0.02, Lp_str,
                            verticalalignment='bottom', horizontalalignment='left',
                            transform=ax.transAxes,
                            color='yellow', weight='bold', fontsize=13)
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            index += 1
            if (ll % ComponentsInFig == (ComponentsInFig - 1)) or ll == L + adaptBias - 1:
                if save_plot == True:
                    pp.savefig(fig)
                index = 0
        pp.close()
        if close_figs:
            plt.close('all')
#    for ll in range(L+adaptBias):
#        print('Sparsity=', np.mean(shapes[ll]>0))

    #%% ###### Plot individual neurons' shape projections with clustering
    a = ceil(sqrt(L + adaptBias))
    b = ceil(old_div((L + adaptBias), a))
    if plot_clustered_shape:
        from sklearn.cluster import spectral_clustering
        pp = PdfPages(Results_folder + 'ClusteredShapes' + resultsName + '.pdf')
        figs = []
        for dd in range(len(shapes[0].shape)):
            figs.append(plt.figure(figsize=(18, 11)))
        for ll in range(L):
            ind = np.reshape(shapes[ll], (1,) + tuple(dims[1:])) > 0
            temp = data[np.repeat(ind, dims[0], axis=0)].reshape(dims[0], -1)
            delta = 1  # affinity transformation parameter
            clust = 3  # number of clusters
            similarity = np.exp(old_div(-np.corrcoef(temp.T), delta))
            labels = spectral_clustering(similarity, n_clusters=clust, eigen_solver='arpack')
            ind2 = np.array(np.nonzero(ind.reshape(-1))).reshape(-1)
            temp_shape = np.repeat(np.zeros_like(shapes[ll]).reshape(-1, 1), clust, axis=1)
            for cc in range(clust):
                temp_shape[ind2[labels == cc], cc] = 1
            temp_shape = temp_shape.reshape(tuple(dims[1:]) + (clust,))
            for dd in range(len(shapes[0].shape)):
                current_fig = figs[dd]
                ax = current_fig.add_subplot(a, b, ll + 1)
                if dd == 2:
                    temp_shape = np.transpose(temp_shape, axes=[1, 0, 2, 3])
                ax.imshow(temp_shape.max(dd))
                plt.setp(ax, xticks=[], yticks=[])
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        if save_plot == True:
            for dd in range(len(shapes[0].shape)):
                current_fig = figs[dd]
                pp.savefig(current_fig)
        pp.close()
        if close_figs:
            plt.close('all')

    #%% ##### Video Shapes
    if video_shapes:
        components = list(range(min(asarray([C, L]))))
        C = len(components)
        if restrict_support == True:
            shape_support = shapes[components[0]] > 0
            for cc in range(C):
                shape_support = np.logical_or(shape_support, shapes[components[cc]] > 0)
            detrended_data = shape_support.reshape((1,) + tuple(dims[1:])) * detrended_data
        fig = plt.figure(figsize=(16, 7))
        mi = 0
        ma = max(data) * scale
#        mi2 = 0
#        ma2 = max(shapes[ll])*max(activity[ll])
        ii = 0
#        import colormaps as cmaps
#        cmap=cmaps.viridis
        cmap = color_map
        a = 3
        b = 1 + C
        ax1 = plt.subplot(a, b, 1)
        im1 = ax1.imshow(data[ii].max(0), vmin=mi, vmax=ma, cmap=cmap)
        title = ax1.set_title('Data')
#        plt.colorbar(im1)
        ax2 = []
        ax4 = []
        ax6 = []
        im2 = []
        im4 = []
        im6 = []
        for cc in range(C):
            ax2.append(plt.subplot(a, b, 2 + cc))
            comp = shapes[components[cc]].max(0) * activity_NonNegative[components[cc], ii]
            ma2 = max(shapes[components[cc]].max(0)) * max(activity_NonNegative[components[cc]]) * scale
            im2.append(ax2[cc].imshow(comp, vmin=0, vmax=ma2, cmap=cmap))
#        ax2[0].set_title('Shape')
#        plt.colorbar(im2)
        ax3 = plt.subplot(a, b, 1 + b)
        im3 = ax3.imshow(data[ii].max(1), vmin=mi, vmax=ma, cmap=cmap)
#        plt.colorbar(im3)
        for cc in range(C):
            ax4.append(plt.subplot(a, b, 2 + b + cc))
            comp = shapes[components[cc]].max(1) * activity_NonNegative[components[cc], ii]
            ma2 = max(shapes[components[cc]].max(1)) * max(activity_NonNegative[components[cc]]) * scale
            im4.append(ax4[cc].imshow(comp, vmin=0, vmax=ma2, cmap=cmap))
#        plt.colorbar(im4)
        ax5 = plt.subplot(a, b, 1 + 2 * b)
        im5 = ax5.imshow(np.transpose(detrended_data[ii].max(2)), vmin=mi, vmax=ma, cmap=cmap)
#        plt.colorbar(im5)
        for cc in range(C):
            ax6.append(plt.subplot(a, b, 2 + 2 * b + cc))
            comp = np.transpose(shapes[components[cc]].max(2)) * activity_NonNegative[components[cc], ii]
            ma2 = max(shapes[components[cc]].max(2)) * max(activity_NonNegative[components[cc]]) * scale
            im6.append(ax6[cc].imshow(comp, vmin=0, vmax=ma2, cmap=cmap))
#        plt.colorbar(im6)
        fig.tight_layout()

        ComponentsActive = np.array([])
        for cc in range(C):
            ComponentsActive = np.append(ComponentsActive, np.nonzero(activity_NonNegative[components[cc]]))
        ComponentsActive = np.unique(ComponentsActive).astype(int)  # frame indices must be integers

        def update(tt):
            ii = ComponentsActive[tt]
            im1.set_data(data[ii].max(0))
            im3.set_data(data[ii].max(1))
            im5.set_data(np.transpose(data[ii].max(2)))
            for cc in range(C):
                im2[cc].set_data(shapes[components[cc]].max(0) * activity_NonNegative[components[cc], ii])
                im4[cc].set_data(shapes[components[cc]].max(1) * activity_NonNegative[components[cc], ii])
                im6[cc].set_data(np.transpose(shapes[components[cc]].max(2)) * activity_NonNegative[components[cc], ii])
            title.set_text('Data, time = %.1f' % ii)

        if save_video == True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(ComponentsActive), blit=True, repeat=False)
            if restrict_support == True:
                ani.save(Results_folder + 'Shapes_Restricted' + resultsName + '.mp4', dpi=dpi, writer=writer)
            else:
                ani.save(Results_folder + 'Shapes_' + resultsName + '.mp4', dpi=dpi, writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(ComponentsActive), blit=True, repeat=False)
            plt.show()

    #%% ##### Plot denoised projections - results
    if plot_residual_projections == True:
        dims = data.shape
        cmap = color_map
        pic_residual = percentile(residual, 95, axis=0)
        pic_denoised = max_intensity(denoised_data, axis=0)
        pic_data = percentile(data, 95, axis=0)
        left = 0.05    # the left side of the subplots of the figure
        right = 0.95   # the right side of the subplots of the figure
        bottom = 0.05  # the bottom of the subplots of the figure
        top = 0.95     # the top of the subplots of the figure
        wspace = 0.05  # the amount of width reserved for blank space between subplots
        hspace = 0.05  # the amount of height reserved for white space between subplots

        fig1 = plt.figure(figsize=(11, 18))
        mi = min(pic_data)
        ma = max(pic_data)
        ax = plt.subplot(311)
        im = ax.imshow(pic_data.max(0), vmin=mi, vmax=ma, cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax, xticks=[], yticks=[])
        ax2 = plt.subplot(312)
        im2 = ax2.imshow(max_intensity(pic_denoised, 0), interpolation='None')
        ax2.set_title('Denoised')
        plt.setp(ax2, xticks=[], yticks=[])
        plt.colorbar(im2)
        ax3 = plt.subplot(313)
        im3 = ax3.imshow(pic_residual.max(0), cmap=cmap)
        ax3.set_title('Residual')
        plt.setp(ax3, xticks=[], yticks=[])
        plt.colorbar(im3)
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        fig2 = plt.figure(figsize=(11, 18))
        mi = min(pic_data)
        ma = max(pic_data)
        ax = plt.subplot(311)
        im = ax.imshow(pic_data.max(1), vmin=mi, vmax=ma, cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax, xticks=[], yticks=[])
        ax2 = plt.subplot(312)
        im2 = ax2.imshow(max_intensity(pic_denoised, 1), interpolation='None')
        ax2.set_title('Denoised')
        plt.colorbar(im2)
        plt.setp(ax2, xticks=[], yticks=[])
        ax3 = plt.subplot(313)
        im3 = ax3.imshow(pic_residual.max(1), cmap=cmap)
        ax3.set_title('Residual')
        plt.colorbar(im3)
        plt.setp(ax3, xticks=[], yticks=[])
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        fig3 = plt.figure(figsize=(11, 18))
        mi = min(pic_data)
        ma = max(pic_data)
        ax = plt.subplot(311)
        im = ax.imshow(pic_data.max(2).T, vmin=mi, vmax=ma, cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax, xticks=[], yticks=[])
        ax2 = plt.subplot(312)
        im2 = ax2.imshow(np.transpose(max_intensity(pic_denoised, 2), [1, 0, 2]), interpolation='None')
        ax2.set_title('Denoised')
        plt.setp(ax2, xticks=[], yticks=[])
        plt.colorbar(im2)
        ax3 = plt.subplot(313)
        im3 = ax3.imshow(np.transpose(pic_residual.max(2)), cmap=cmap)
        ax3.set_title('Residual')
        plt.colorbar(im3)
        plt.setp(ax3, xticks=[], yticks=[])
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        if save_plot == True:
            pp = PdfPages(Results_folder + 'Data_Denoised_Residual_Projections' + resultsName + '.pdf')
            pp.savefig(fig1)
            pp.savefig(fig2)
            pp.savefig(fig3)
            pp.close()
#    fig = plt.figure()
#    plt.plot(MSE_array)
#    plt.xlabel('Iteration')
#    plt.ylabel('MSE')
#    plt.show()

    #%% ##### Plot denoised slices - results
#    z_slices=[0,2,4,6,8] #which z slices to look at in slice plots/videos
    z_slices = list(range(dims[min_dim + 1]))  # which z slices to look at in slice plots/videos
    D = len(z_slices)
    if plot_residual_slices == True:
        dims = data.shape
        cmap = color_map
        pic_residual = percentile(residual, 95, axis=0)
        pic_denoised = max_intensity(denoised_data, axis=0)
        pic_data = percentile(data, 95, axis=0)
        a = 3  # number of rows
        fig1 = plt.figure(figsize=(18, 11))
        mi = min(pic_data)
        ma = max(pic_data)
        for kk in range(D):
            ax2 = plt.subplot(a, D, kk + 1)
            temp = np.squeeze(np.take(pic_denoised, (z_slices[kk],), axis=min_dim))
            im2 = ax2.imshow(temp, interpolation='None')
            ax2.set_title('Denoised')
            plt.setp(ax2, xticks=[], yticks=[])
            plt.colorbar(im2)
            ax = plt.subplot(a, D, kk + D + 1)
            temp = np.squeeze(np.take(pic_data, (z_slices[kk],), axis=min_dim))
            im = ax.imshow(temp, vmin=mi, vmax=ma, cmap=cmap)
            ax.set_title('Data')
            plt.colorbar(im)
            plt.setp(ax, xticks=[], yticks=[])
            ax3 = plt.subplot(a, D, kk + 2 * D + 1)
            temp = np.squeeze(np.take(pic_residual, (z_slices[kk],), axis=min_dim))
            im3 = ax3.imshow(temp, cmap=cmap)
            ax3.set_title('Residual')
            plt.setp(ax3, xticks=[], yticks=[])
            plt.colorbar(im3)
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        if save_plot == True:
            pp = PdfPages(Results_folder + 'Data_Denoised_Residual_Slice_' + resultsName + '.pdf')
            pp.savefig(fig1)
            pp.close()
#    fig = plt.figure()
#    plt.plot(MSE_array)
#    plt.xlabel('Iteration')
#    plt.ylabel('MSE')
#    plt.show()

    #%% ##### 2D Video Residual
    if video_residual_2D:
        fig = plt.figure(figsize=(16, 7))
        mi = 0
        ma = np.percentile(data, saturation_percentile)
        mi3 = 0
        ma3 = old_div(ma, np.max([np.floor(old_div(ma, np.percentile(residual[residual > 0], saturation_percentile))), 1]))
        ii = 0
#        import colormaps as cmaps
#        cmap=cmaps.viridis
        cmap = color_map
        a = 1
        b = 3
        im_array = []
        temp = np.shape(data[ii])
        ax1 = plt.subplot(a, b, 1)
        pic = denoised_data[ii]
        im_array += [ax1.imshow(pic, interpolation='None')]
        ax1.set_title('Denoised')
        plt.setp(ax1, xticks=[], yticks=[])
        ax2 = plt.subplot(a, b, 2)
        pic = data[ii]
        im_array += [ax2.imshow(pic, vmin=mi, vmax=ma, cmap=cmap)]
        title = ax2.set_title('Data')
        plt.setp(ax2, xticks=[], yticks=[])
        divider = make_axes_locatable(ax2)
        cax2 = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im_array[-1], cax=cax2)
        ax3 = plt.subplot(a, b, 3)
        pic = residual[ii]
        im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3, cmap=cmap)]
        ax3.set_title('Residual x' + '%.1f' % (old_div(ma, ma3)))
        plt.setp(ax3, xticks=[], yticks=[])
        divider = make_axes_locatable(ax3)
        cax3 = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im_array[-1], cax=cax3)
#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        def update(ii):
            im_array[0].set_data(denoised_data[ii])
            im_array[1].set_data(data[ii])
            im_array[2].set_data(residual[ii])
            if frame_rate != []:
                title.set_text('Data, time = %.2f sec' % (old_div(ii, frame_rate)))
            else:
                title.set_text('Data, time = %.1f' % ii)

        if save_video == True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_2D_' + resultsName + '.avi', dpi=dpi, writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()

    #%% ##### Video Projections Residual
    if video_residual:
        fig = plt.figure(figsize=(16, 7))
        mi = 0
        ma = max(data) * scale
        mi3 = 0
        ma3 = max(residual) * scale
        ii = 0
#        import colormaps as cmaps
#        cmap=cmaps.viridis
        cmap = color_map
        spatial_dims_ind = list(range(len(dims) - 1))
        D = len(spatial_dims_ind)
        a = D
        b = 3
        im_array = []
        transpose_flags = []
        for kk in range(D):
            transpose_flags += [False]
            temp = np.shape(data[ii].max(spatial_dims_ind[kk]))
            if temp[0] > temp[1]:
                transpose_flags[kk] = True
        for kk in range(D):
            ax1 = plt.subplot(a, b, D * kk + 1)
            if transpose_flags[kk] == False:
                pic = max_intensity(denoised_data[ii], spatial_dims_ind[kk])
            else:
                pic = np.transpose(max_intensity(denoised_data[ii], spatial_dims_ind[kk]), [1, 0, 2])
            im_array += [ax1.imshow(pic, interpolation='None')]
            ax1.set_title('Denoised')
            plt.colorbar(im_array[-1])
            plt.setp(ax1, xticks=[], yticks=[])
            ax2 = plt.subplot(a, b, D * kk + 2)
            if transpose_flags[kk] == False:
                pic = data[ii].max(spatial_dims_ind[kk])
            else:
                pic = np.transpose(data[ii].max(spatial_dims_ind[kk]))
            im_array += [ax2.imshow(pic, vmin=mi, vmax=ma, cmap=cmap)]
            title = ax2.set_title('Data')
            plt.colorbar(im_array[-1])
            plt.setp(ax2, xticks=[], yticks=[])
            ax3 = plt.subplot(a, b, D * kk + 3)
            if transpose_flags[kk] == False:
                pic = residual[ii].max(spatial_dims_ind[kk])
            else:
                pic = np.transpose(residual[ii].max(spatial_dims_ind[kk]))
            im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3, cmap=cmap)]
            ax3.set_title('Residual')
            plt.colorbar(im_array[-1])
            plt.setp(ax3, xticks=[], yticks=[])
#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        def update(ii):
            for kk in range(D):
                if transpose_flags[kk] == False:
                    im_array[kk * D].set_data(max_intensity(denoised_data[ii], spatial_dims_ind[kk]))
                    im_array[kk * D + 1].set_data(data[ii].max(spatial_dims_ind[kk]))
                    im_array[kk * D + 2].set_data(residual[ii].max(spatial_dims_ind[kk]))
                else:
                    im_array[kk * D].set_data(np.transpose(max_intensity(denoised_data[ii], spatial_dims_ind[kk]), [1, 0, 2]))
                    im_array[kk * D + 1].set_data(np.transpose(data[ii].max(spatial_dims_ind[kk])))
                    im_array[kk * D + 2].set_data(np.transpose(residual[ii].max(spatial_dims_ind[kk])))
            title.set_text('Data, time = %.1f' % ii)

        if save_video == True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_Projections' + resultsName + '.mp4', dpi=dpi, writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()

    #%% ##### Video Slices Residual
#    z_slices=[0,2,4,6,8] #which z slices to look at in slice plots/videos
    z_slices = list(range(dims[min_dim + 1]))  # which z slices to look at in slice plots/videos
    if video_slices:
        fig = plt.figure(figsize=(16, 7))
        mi = 0
        ma = np.percentile(data[data > 0], saturation_percentile)
        mi3 = 0
        ma3 = np.percentile(data[data > 0], saturation_percentile)
        ii = 0
#        import colormaps as cmaps
#        cmap=cmaps.viridis
        cmap = color_map
        a = 3
        D = len(z_slices)  # number of z slices shown
        im_array = []
        transpose_flag = True
        for kk in range(D):
            ax1 = plt.subplot(a, D, kk + 1)
            temp = np.squeeze(np.take(denoised_data[ii], (z_slices[kk],), axis=min_dim))
            if transpose_flag == False:
                pic = temp
            else:
                pic = np.transpose(temp, [1, 0, 2])
            im_array += [ax1.imshow(pic, interpolation='None')]
            ax1.set_title('Denoised, z=' + str(z_slices[kk] + 1))
            plt.colorbar(im_array[-1])
            plt.setp(ax1, xticks=[], yticks=[])
            ax2 = plt.subplot(a, D, kk + D + 1)
            temp = np.squeeze(np.take(data[ii], (z_slices[kk],), axis=min_dim))
            if transpose_flag == False:
                pic = temp
            else:
                pic = np.transpose(temp)
            im_array += [ax2.imshow(pic, vmin=mi, vmax=ma, cmap=cmap)]
            title = ax2.set_title('Data')
            plt.colorbar(im_array[-1])
            plt.setp(ax2, xticks=[], yticks=[])
            ax3 = plt.subplot(a, D, kk + 2 * D + 1)
            temp = np.squeeze(np.take(residual[ii], (z_slices[kk],), axis=min_dim))
            if transpose_flag == False:
                pic = temp
            else:
                pic = np.transpose(temp)
            im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3, cmap=cmap)]
            ax3.set_title('Residual')
            plt.colorbar(im_array[-1])
            plt.setp(ax3, xticks=[], yticks=[])
#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)

        def update(ii):
            for kk in range(D):
                temp1 = np.squeeze(np.take(denoised_data[ii], (z_slices[kk],), axis=min_dim))
                temp2 = np.squeeze(np.take(data[ii], (z_slices[kk],), axis=min_dim))
                temp3 = np.squeeze(np.take(residual[ii], (z_slices[kk],), axis=min_dim))
                if transpose_flag == False:
                    im_array[a * kk].set_data(temp1)
                    im_array[a * kk + 1].set_data(temp2)
                    im_array[a * kk + 2].set_data(temp3)
                else:
                    im_array[a * kk].set_data(np.transpose(temp1, [1, 0, 2]))
                    im_array[a * kk + 1].set_data(np.transpose(temp2))
                    im_array[a * kk + 2].set_data(np.transpose(temp3))
            title.set_text('Data, time = %.1f' % ii)

        if save_video == True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_Slices' + resultsName + '.mp4', dpi=dpi, writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()
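
# Minimal usage sketch for PlotAll (not part of the original source). `params`
# is assumed to be a simple namespace exposing the attributes PlotAll reads
# (data_name, SuperVoxelize, ThresholdData, Background_num, TargetAreaRatio);
# the attribute values and the results-file name below are illustrative only,
# and the named file must already exist under 'NMF_Results/'.
def _example_PlotAll():
    from argparse import Namespace
    params = Namespace(data_name='demo_dataset',       # dataset label passed to GetData
                       SuperVoxelize=False,            # skip supervoxelization
                       ThresholdData=False,            # skip data thresholding
                       Background_num=1,               # first results file contains a background component
                       TargetAreaRatio=[0.005, 0.05])  # sparsity bounds used when Prune=True
    PlotAll(['demo_run'], params)  # loads 'NMF_Results/demo_run' and writes figures to 'Results/'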
import numpy as np
from numpy import asarray, zeros, ones, percentile, repeat, arange, exp, ix_, prod
from past.utils import old_div
# The helper routines used below (DownScale, GetBox, RegionCut, RegionAdd,
# HALS4shape, FISTA4shape, HALS4activity, ExponentialSearch, addComponent,
# RenormalizeDeleteSort, LargestConnectedComponent, LargestWatershedRegion,
# GrowMasks, SmoothBackground, GetSnPSDArray, PruneComponents, MergeComponents)
# and OASIS' `deconvolve` are assumed to be defined/imported elsewhere in this repo.


def LocalNMF(data, centers, sig, NonNegative=True, FinalNonNegative=True, verbose=False,
             adaptBias=True, TargetAreaRatio=[], estimateNoise=False, PositiveError=False,
             MedianFilt=False, Connected=False, FixSupport=False, WaterShed=False,
             SmoothBkg=False, FineTune=True, Deconvolve=False, SigmaMask=[],
             updateLambdaIntervals=2, updateRhoIntervals=2, addComponentsIntervals=1,
             bkg_per=20, SigmaBlur=[], iters=10, iters0=[30], mbs=[1], ds=1,
             lam1_s=0, lam1_t=0, lam2_s=0, lam2_t=0):
    """
    Parameters
    ----------
    data : array, shape (T, X, Y[, Z])
        block of the data
    centers : array, shape (L, D)
        L centers of suspected neurons, where D is the spatial dimension (2 or 3)
    sig : array, shape (D,)
        size of the gaussian kernel in the different spatial directions
    NonNegative : boolean
        if True, neuronal activity should be considered as non-negative
    FinalNonNegative : boolean
        if False, the last activity iteration is done without the non-negativity
        constraint, even if NonNegative==True
    verbose : boolean
        print progress and record MSE if true (about 2x slower)
    adaptBias : boolean
        subtract a rank-1 estimate of the bias (background)
    TargetAreaRatio : list of length 2
        lower and upper bounds on the sparsity of non-background components
    estimateNoise : boolean
        estimate the noise variance and use it to determine whether to add
        components, and to modify sparsity by affecting lam1_s (does not work very well)
    PositiveError : boolean
        do not allow pixels in which the residual (summed over time) becomes
        negative, by increasing lam1_s in these pixels
    MedianFilt : boolean
        apply a median filter to the spatial components
    Connected : boolean
        impose connectedness of each spatial component by keeping only the largest
        non-zero connected component in each iteration of HALS
    WaterShed : boolean
        impose that each spatial component has a single watershed region
    SmoothBkg : boolean
        remove local peaks from the background component
    FixSupport : boolean
        do not allow spatial components to be non-zero where the sub-sampled
        spatial components are zero
    FineTune : boolean
        fine-tune the main iterations on the full data; if not, use the (last)
        downsampled data
    Deconvolve : boolean
        deconvolve activity to get a smoothed (denoised) calcium trace;
        this is done only on the main iterations, and only if FineTune=True
    SigmaMask : scalar or empty
        if not [], then update the masks so that they extend SigmaMask around
        the non-zero support of the shapes
    SigmaBlur : scalar
        if not [], then de-blur the spatial components using a Gaussian kernel
        of this width
    updateLambdaIntervals : int
        update lam1_s every this number of HALS iterations, to match constraints
    updateRhoIntervals : int
        decrease rho, the update rate of lam1_s, every this number of
        updateLambdaIntervals HALS iterations (only active during main iterations)
    addComponentsIntervals : int
        add a new component, if possible, every this number of
        updateLambdaIntervals HALS iterations (only active during sub-sampled iterations)
    bkg_per : float
        the background is initialized at this height (percentile image)
    iters : int
        number of final iterations on the whole data
    iters0 : list
        numbers of initial iterations on a subset
    mbs : list
        minibatch sizes for temporal downsampling
    ds : int or list
        factor for spatial downsampling; can be an integer or a list of the size
        of the spatial dimensions
    lam1_s : float
        L_1 regularization constant for sparsity of shapes
    lam2_s : float
        L_2 regularization constant for sparsity of shapes
    lam1_t : float
        L_1 regularization constant for sparsity of activity
    lam2_t : float
        L_2 regularization constant for sparsity of activity

    Returns
    -------
    MSE_array : list (empty if verbose is False)
        mean square error during algorithm operation
    shapes : array, shape (L+adaptBias, X, Y[, Z])
        the neuronal shape vectors (empty if no components found)
    activity : array, shape (L+adaptBias, T)
        the neuronal activity for each shape (empty if no components found)
    """
    # Catch errors
    if ds != 1 and SigmaBlur != []:
        raise NameError('case ds!=1 and SigmaBlur!=[] not yet written in NMF code')

    # Initialize parameters
    dims = data.shape  # data dimensions
    D = len(dims)  # number of data dimensions
    R = 3 * asarray(sig)  # size of bounding box is 3 times the size of the neuron
    L = len(centers)  # number of components (not including background)
    inner_iterations = 10  # number of iterations in inner loops
    shapes = []  # list of spatial components
    mask = []  # binary array, support of spatial components
    boxes = zeros((L, D - 1, 2), dtype=int)  # initial support of spatial components
    MSE_array = []  # CNMF residual error
    mb = mbs[0] if iters0[0] > 0 else 1
    activity = zeros((L, old_div(dims[0], mb)))  # array of temporal components
    lam1_s0 = np.copy(lam1_s)  # initial spatial sparsity (L1) parameters
    if TargetAreaRatio != []:
        if TargetAreaRatio[0] > TargetAreaRatio[1]:
            print('WARNING - TargetAreaRatio[0]>TargetAreaRatio[1] !!!')
    if iters0[0] == 0:
        ds = 1

    ### Initialize shapes, activity, and residual ###
    data0, dims0 = DownScale(data, mb, ds)  # downscaled data and dimensions
    if isinstance(ds, int):
        ds = ds * np.ones(D - 1)
    if D == 4:  # downscale activity
        activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))),
                         list(map(int, old_div(centers[:, 1], ds[1]))),
                         list(map(int, old_div(centers[:, 2], ds[2])))].T
    else:
        activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))),
                         list(map(int, old_div(centers[:, 1], ds[1])))].T
    data0 = data0.reshape(dims0[0], -1)  # reshape data0 to the more convenient time x space form
    Energy0 = np.sum(data0 ** 2, axis=0)  # data0 energy per pixel
    data0sum = np.sum(data0, axis=0)  # for sign check later
    data = data.astype('float').reshape(dims[0], -1)  # reshape data to the more convenient time x space form
    datasum = np.sum(data, axis=0)  # for sign check later
    # float is faster than float32; presumably float32 gets converted later on to float again and again
    Energy = np.sum(data ** 2, axis=0)  # data energy per pixel

    # extract shapes and activity from the given centers
    for ll in range(L):
        boxes[ll] = GetBox(old_div(centers[ll], ds), old_div(R, ds), dims0[1:])
        temp = zeros(dims0[1:])
        temp[[slice(*a) for a in boxes[ll]]] = 1
        mask += np.where(temp.ravel())
        temp = [old_div((arange(int(old_div(dims[i + 1], ds[i]))) - int(old_div(centers[ll][i], ds[i]))) ** 2,
                        (2 * (old_div(sig[i], ds[i])) ** 2)) for i in range(D - 1)]
        temp = exp(-sum(ix_(*temp)))
        temp.shape = (1,) + dims0[1:]
        temp = RegionCut(temp, boxes[ll])
        shapes.append(temp[0])
    S = zeros((L + adaptBias, prod(dims0[1:])))  # shape components
    for ll in range(L):
        S[ll] = RegionAdd(zeros((1,) + dims0[1:]), shapes[ll].reshape(1, -1), boxes[ll]).ravel()
    if adaptBias:
        # Initialize the background as the bkg_per percentile image
        S[-1] = percentile(data0, bkg_per, 0)
        activity = np.r_[activity, ones((1, dims0[0]))]
    lam1_s = lam1_s0 * np.ones_like(S) * mbs[0]  # initialize sparsity parameters

    ### Get shape estimates on a subset of the data ###
    if iters0[0] > 0:
        for it in range(len(iters0)):
            if estimateNoise:
                sn_target, sn_std = GetSnPSDArray(data0)  # target noise level
            else:
                sn_target = np.zeros(prod(dims0[1:]))
                sn_std = sn_target
            MSE_target = np.mean(sn_target ** 2)
            ES = ExponentialSearch(lam1_s)  # object to update sparsity parameters
            lam1_s = ES.lam
            for kk in range(iters0[it]):
                # update sparsity parameters
                if kk % updateLambdaIntervals == 0:
                    sn = old_div(np.sqrt(Energy0 - 2 * np.sum(np.dot(activity, data0) * S, axis=0)
                                         + np.sum(np.dot(np.dot(activity, activity.T), S) * S, axis=0)), dims0[0])  # efficient way to calculate MSE per pixel
                    delta_sn = sn - sn_target  # noise margin
                    signcheck = (data0sum - np.dot(np.sum(activity.T, axis=0), S)) < 0
                    if PositiveError:  # obsolete
                        delta_sn[signcheck] = -float("inf")  # the residual should not have negative pixels, so we increase lambda for these pixels
                    if len(S) == 0:
                        spars = 0
                    else:
                        spars = np.mean(S > 0, axis=1)
                    temp = repeat(delta_sn.reshape(1, -1), L + adaptBias, axis=0)
                    if TargetAreaRatio == []:
                        cond_decrease = temp > sn_std
                        cond_increase = temp < -sn_std
                    else:
                        if adaptBias:
                            spars[-1] = old_div((TargetAreaRatio[1] + TargetAreaRatio[0]), 2)  # ignore the sparsity target for the background (bias) component
                        temp2 = repeat(spars.reshape(-1, 1), len(S[0]), axis=1)
                        cond_increase = np.logical_or(temp2 > TargetAreaRatio[1], temp < -sn_std)
                        cond_decrease = np.logical_and(temp2 < TargetAreaRatio[0], temp > sn_std)
                    ES.update(cond_decrease, cond_increase)
                    lam1_s = ES.lam

                    # print residual error and additional information
                    MSE = np.mean(sn ** 2)
                    if verbose and L > 0:
                        print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(
                            MSE, MSE_target, np.mean(spars[:L]), np.mean(lam1_s)))

                    # add a new component
                    if (kk % addComponentsIntervals == 0) and (kk != iters0[it] - 1):
                        delta_sn[signcheck] = -float("inf")  # the residual should not have negative pixels
                        new_cent = np.argmax(delta_sn)  # should I smooth the data a bit first?
                        MSE_std = np.mean(sn_std ** 2)
                        checkNoZero = not ((0 in np.sum(activity, axis=1)) and (0 in np.sum(S, axis=1)))
                        if ((MSE - MSE_target > 2 * MSE_std) and checkNoZero and (delta_sn[new_cent] > sn_std[new_cent])):
                            S, activity, mask, centers, boxes, L = addComponent(new_cent, data0, dims0, old_div(R, ds),
                                                                                S, activity, mask, centers, boxes, adaptBias)
                            new_lam = lam1_s0 * np.ones_like(data0[0, :]).reshape(1, -1)
                            lam1_s = np.insert(lam1_s, 0, values=new_lam, axis=0)
                            ES = ExponentialSearch(lam1_s)  # we need to restart the exponential search each time we add a component

                # apply additional constraints/processing
                if SigmaBlur == []:
                    S = HALS4shape(data0, S, activity, mask, lam1_s, lam2_s, adaptBias, inner_iterations)
                else:  # obsolete
                    S = FISTA4shape(data0, S, activity, mask, lam1_s, adaptBias, SigmaBlur, dims0)
                if Connected == True:
                    S = LargestConnectedComponent(S, dims0, adaptBias)
                if WaterShed == True:
                    S = LargestWatershedRegion(S, dims0, adaptBias)
                activity = HALS4activity(data0, S, activity, NonNegative, lam1_t, lam2_t, dims0, SigmaBlur, inner_iterations)
                if SigmaMask != []:
                    mask = GrowMasks(S, mask, boxes, dims0, adaptBias, SigmaMask)
                S, activity, mask, centers, boxes, ES, L = RenormalizeDeleteSort(S, activity, mask, centers, boxes,
                                                                                 ES, adaptBias, MedianFilt)
                lam1_s = ES.lam
                if SmoothBkg == True:
                    S = SmoothBackground(S, dims0, adaptBias, tuple(old_div(np.array(sig), np.array(ds))))
                print('Subsampled iteration', kk, 'it=', it, 'L=', L)

            # use the next (smaller) value for temporal downscaling
            if it < len(iters0) - 1:
                mb = mbs[it + 1]
                data0 = data[:len(data) // mb * mb].reshape(-1, mb, prod(dims[1:])).mean(1)
                if D == 4:
                    data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds[0],
                                          int(old_div(dims[2], ds[1])), ds[1],
                                          int(old_div(dims[3], ds[2])), ds[2]).mean(-1).mean(-2).mean(-3)
                else:
                    data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds[0],
                                          int(old_div(dims[2], ds[1])), ds[1]).mean(-1).mean(-2)
                data0.shape = (len(data0), -1)
                activity = ones((L + adaptBias, len(data0))) * activity.mean(1).reshape(-1, 1)
                lam1_s = lam1_s * mbs[it + 1] / mbs[it]
                activity = HALS4activity(data0, S, activity, NonNegative, lam1_t, lam2_t, dims0, SigmaBlur, 30)
                S, activity, mask, centers, boxes, ES, L = RenormalizeDeleteSort(S, activity, mask, centers, boxes,
                                                                                 ES, adaptBias, MedianFilt)
                lam1_s = ES.lam

    ### Stop adding components ###
    if L == 0:  # if no non-background components are found, return empty arrays
        return [], [], []

    if FineTune:
        ### Upscale back to the full data ###
        activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1)
        data0 = data
        if D == 4:
            S = repeat(repeat(repeat(S.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2), ds[2], 3)
            lam1_s = repeat(repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2), ds[2], 3)
        else:
            S = repeat(repeat(S.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2)
            lam1_s = repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2)
        dims0 = dims  # note: reassigned only after reshaping with the downscaled dims above
        for dd in range(1, D):
            while S.shape[dd] < dims[dd]:
                shape_append = np.array(S.shape)
                shape_append[dd] = 1
                S = np.append(S, values=np.take(S, -1, axis=dd).reshape(shape_append), axis=dd)
                lam1_s = np.append(lam1_s, values=np.take(lam1_s, -1, axis=dd).reshape(shape_append), axis=dd)
        S = S.reshape(L + adaptBias, -1)
        lam1_s = lam1_s.reshape(L + adaptBias, -1)
        for ll in range(L):
            boxes[ll] = GetBox(centers[ll], R, dims[1:])
            temp = zeros(dims[1:])
            temp[[slice(*a) for a in boxes[ll]]] = 1
            mask[ll] = np.where(temp.ravel())[0]
        if FixSupport:  # obsolete
            for ll in range(L):
                lam1_s[ll, S[ll] == 0] = float("inf")
        ES = ExponentialSearch(lam1_s)
        activity = HALS4activity(data0, S, activity, NonNegative, lam1_t, lam2_t, dims0, SigmaBlur, 30)
        S, activity, mask, centers, boxes, ES, L = RenormalizeDeleteSort(S, activity, mask, centers, boxes,
                                                                         ES, adaptBias, MedianFilt)
        lam1_s = ES.lam
        if estimateNoise:
            sn_target, sn_std = GetSnPSDArray(data0)  # target noise level
        else:
            sn_target = np.zeros(prod(dims0[1:]))
            sn_std = sn_target
        MSE_target = np.mean(sn_target ** 2)
        MSE_std = np.mean(sn_std ** 2)
#        MSE = np.mean((data0-np.dot(activity.T,S))**2)

    #### Main Loop ####
    print('starting main NMF loop')
    for kk in range(iters):
        lam1_s = ES.lam  # update sparsity parameters
        if SigmaBlur == []:
            S = HALS4shape(data0, S, activity, mask, lam1_s, lam2_s, adaptBias, inner_iterations)
        else:  # obsolete
            S = FISTA4shape(data0, S, activity, mask, lam1_s, adaptBias, SigmaBlur, dims0)
        # apply additional constraints/processing
        if Connected == True:
            S = LargestConnectedComponent(S, dims0, adaptBias)
        if WaterShed == True:
            S = LargestWatershedRegion(S, dims0, adaptBias)
        if kk == iters - 1:
            if FinalNonNegative == False:
                NonNegative = False
        activity = HALS4activity(data0, S, activity, NonNegative, lam1_t, lam2_t, dims0, SigmaBlur, inner_iterations)
        if FineTune and Deconvolve:
            for ll in range(L):
                if np.sum(np.abs(activity[ll]) > 0) > 30:  # make sure there is enough signal before we try to deconvolve
                    activity[ll], _, _, _, _ = deconvolve(activity[ll], penalty=0)
        if SigmaMask != []:
            mask = GrowMasks(S, mask, boxes, dims0, adaptBias, SigmaMask)
        S, activity, mask, centers, boxes, ES, L = RenormalizeDeleteSort(S, activity, mask, centers, boxes,
                                                                         ES, adaptBias, MedianFilt)

        # measure MSE and update the sparsity parameters
        print('main iteration kk=', kk, 'L=', L)
        if (kk + 1) % updateLambdaIntervals == 0:
            sn = np.sqrt(old_div((Energy - 2 * np.sum(np.dot(activity, data0) * S, axis=0)
                                  + np.sum(np.dot(np.dot(activity, activity.T), S) * S, axis=0)), dims0[0]))
            delta_sn = sn - sn_target
            MSE = np.mean(sn ** 2)
            signcheck = (datasum - np.dot(np.sum(activity.T, axis=0), S)) < 0
            if PositiveError:  # obsolete
                delta_sn[signcheck] = -float("inf")  # the residual should not have negative pixels, so we increase lambda for these pixels
            if S == []:
                spars = 0
            else:
                spars = np.mean(S > 0, axis=1)
            temp = repeat(delta_sn.reshape(1, -1), L + adaptBias, axis=0)
            if TargetAreaRatio == []:
                cond_decrease = temp > sn_std
                cond_increase = temp < -sn_std
            else:
                if adaptBias:
                    spars[-1] = old_div((TargetAreaRatio[1] + TargetAreaRatio[0]), 2)  # ignore the sparsity target for the background (bias) component
                temp2 = repeat(spars.reshape(-1, 1), len(S[0]), axis=1)
                cond_increase = np.logical_or(temp2 > TargetAreaRatio[1], temp < -sn_std)
                cond_decrease = np.logical_and(temp2 < TargetAreaRatio[0], temp > sn_std)
            ES.update(cond_decrease, cond_increase)
            lam1_s = ES.lam
            if kk < old_div(iters, 3):  # restart the exponential search unless enough iterations have passed
                ES = ExponentialSearch(lam1_s)
            else:
                if not (np.any(cond_increase) or np.any(cond_decrease)):
                    print('sparsity target reached')
                    break
                if L + adaptBias > 1:  # if we have more than one component, just keep exponentiated gradient descent instead
                    if (kk + 1) % updateRhoIntervals == 0:  # update rho every updateRhoIntervals if we have still not converged
                        if np.any(spars[:L] < TargetAreaRatio[0]) or np.any(spars[:L] > TargetAreaRatio[1]):
                            ES.rho = 2 - old_div(1, ES.rho)
                            print('rho=', ES.rho)
                    ES = ExponentialSearch(lam1_s, rho=ES.rho)

            # print MSE and other information
            if verbose:
                print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(
                    MSE, MSE_target, np.mean(spars[:L]), np.mean(lam1_s)))
                if kk == (iters - 1):
                    print('Maximum iteration limit reached')
                MSE_array.append(MSE)

    # Some post-processing
    S = S.reshape((-1,) + dims[1:])
    S, activity, L = PruneComponents(S, activity, L)  # prune "bad" components
    if len(S) > 1:
        S, activity, L = MergeComponents(S, activity, L, threshold=0.9, sig=10)  # merge very similar components
    if not FineTune:
        activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1)  # extract activity from the full data
        activity = HALS4activity(data, S.reshape((len(S), -1)), activity, NonNegative, lam1_t, lam2_t, dims0, SigmaBlur, iters=30)

    return asarray(MSE_array), S, activity


# example to check the code works
#T = 1000
#X = 201
#Y = 101
#data = np.random.randn(T, X, Y)
#centers = asarray([[40, 30]])
#data[:, 30:45, 25:33] += 2*np.sin(np.array(range(T))/200).reshape(-1,1,1)*np.ones([T,15,8])
#sig = [300, 300]
#
#MSE_array, shapes, activity = LocalNMF(
#    data, centers, sig, NonNegative=True, verbose=True, lam1_s=0.1, adaptBias=True)
#
#import matplotlib.pyplot as plt
#plt.imshow(shapes[0])
#
#for ll in range(len(shapes)):
#    print(np.mean(shapes[ll] > 0))