Exemple #1
0
def ca_deconvolution(ddf_trace, l0=False):
    """Estimate spikes from a calcium trace with several deconvolution methods.

    Methods applied:

    1. OASIS (https://github.com/j-friedrich/OASIS) -- always run.
    2. Derivative-based event detection (script from Peter) -- always run.
    3. AR-FPOP -- run only when ``l0`` is truthy.

    Args:
        ddf_trace: per the original author's notes, a 1-d numpy array of
            length n (the number of time steps in the calcium trace).
        l0: when truthy, AR-FPOP is run as well; ``arfpop_ctypes`` is only
            imported in that case.

    Returns:
        dict mapping the method name ('OASIS', 'event_detection' and, when
        requested, 'arfpop') to that method's spike estimate. Per the
        original author's notes each value is a 1-d array of length n.

    TODO:
        Add ML Spike and possibly one of the supervised methods.
    """
    # Method 1: OASIS. The call yields five values; only the second is kept.
    _c, oasis_spikes, _b, _g, _lam = deconvolve(np.double(ddf_trace),
                                                penalty=1)
    results = {'OASIS': oasis_spikes}

    # Method 2: event detection. Adjacent events are merged (delta=3) and the
    # resulting heights are written at the event times into an array shaped
    # like the input, which is zero everywhere else.
    event_flags, event_sizes = ed.get_events_derivative(ddf_trace)
    event_times, event_heights = ed.concatenate_adjacent_events(
        event_flags, event_sizes, delta=3)
    event_trace = np.zeros_like(ddf_trace)
    event_trace[event_times] = event_heights
    results['event_detection'] = event_trace

    # Method 3 (FastLZeroSpikeInference / AR-FPOP) is opt-in.
    if not l0:
        return results

    import arfpop_ctypes as af

    # Default parameters -- flagged by the original author as needing tuning.
    decay, spike_penalty, use_constraint = 0.99, 0.25, False
    fit = af.arfpop(ddf_trace, decay, spike_penalty, use_constraint)
    results['arfpop'] = fit['pos_spike_mag']
    return results
Exemple #2
0
def ca_deconvolution(ddf_trace):
    """Run the supported calcium-trace deconvolution methods on one trace.

    Currently supported:

    1. OASIS (https://github.com/j-friedrich/OASIS)
    2. Event detection script from Peter

    Args:
        ddf_trace: per the original author's notes, a 1-d numpy array of
            length n (the number of time steps in the calcium trace).

    Returns:
        dict with keys 'OASIS' and 'event_detection'; each value is that
        method's spike estimate (per the original author's notes, a 1-d
        array of length n).

    TODO:
        Add AR-FPOP, ML Spike and possibly one of the supervised methods.
    """
    # --- OASIS: five values come back, only the second one is reported.
    oasis_output = deconvolve(np.double(ddf_trace), penalty=1)
    _c, spike_estimate, _b, _g, _lam = oasis_output

    # --- Event detection: detect events, merge adjacent ones (delta=3),
    # then place the merged heights at their times in a zero array shaped
    # like the input trace.
    detected, magnitudes = ed.get_events_derivative(ddf_trace)
    merged = ed.concatenate_adjacent_events(detected, magnitudes, delta=3)
    merged_times, merged_heights = merged
    events = np.zeros_like(ddf_trace)
    events[merged_times] = merged_heights

    return {
        'OASIS': spike_estimate,
        'event_detection': events,
    }
def PlotAll(SaveNames,params):
    from numpy import  min, max, percentile,asarray,ceil,sqrt
    import numpy as np
    import sys
    from scipy.signal import welch
    from pylab import load
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib.backends.backend_pdf import PdfPages
#    from scipy.ndimage.measurements import label    
    from AuxilaryFunctions import GetRandColors, max_intensity,SuperVoxelize,GetData,PruneComponents,SplitComponents,ThresholdShapes,MergeComponents,ThresholdData,make_sure_path_exists
#    from BlockLocalNMF_AuxilaryFunctions import HALS4activity
    from mpl_toolkits.axes_grid1 import make_axes_locatable    
    
    # makse sure relevant folders exist, and add to path    
    Results_folder='Results/'
    make_sure_path_exists(Results_folder)
        
    OASIS_path='OASIS/'   
    make_sure_path_exists(OASIS_path)
    sys.path.append(OASIS_path)
    from functions import deconvolve

    ## plotting params 
    # what to plot 
    plot_activities=True
    plot_activities_PSD=False
    plot_shapes_projections=True
    plot_shapes_slices=False
    plot_activityCorrs=False
    plot_clustered_shape=False
    plot_residual_slices=False
    plot_residual_projections=False
    # videos to generate
    video_shapes=False
    video_residual=True
    video_slices=False
    # what to save
    save_video=True
    save_plot=True
    close_figs=True#close all figs right after saving (to avoid memory overload)
    # PostProcessing   
    Split=False   
    Threshold=False   #threshold shapes in the end and keep only connected components
    Prune=False # Remove "Bad" components (where bad is defined within SplitComponent fucntion)
    Merge=True # Merge highly correlated nearby components
    FineTune=False # SHould we fine tune activity after post-processing? (mainly after merging)
    IncludeBackground=False #should we include the background as an extracted component?
    
    # how to plot
    detrend=True #should we detrend the data (remove background component)?
    scale=2 #scale colormap to enhance colors
    satuartion_percentile=96 #saturate colormap ont this percentile, when ma=percentile is used
    dpi=200 #for videos
    restrict_support=True #in shape video, zero out data outside support of shapes
    C=4 #number of components to show in shape videos (if larger then number of shapes L, then we automatically set C=L)
    color_map='gray' #'gnuplot'
    frame_rate=10.0 #Hz
    
    # Fetch experimental 3D data 
    data=GetData(params.data_name)
    if params.SuperVoxelize==True:
        data=SuperVoxelize(data)
    
    if params.ThresholdData==True:
        data=ThresholdData(data)        
        
    dims=np.shape(data)
    
    if len(dims)<4:
        plot_Shapes2D=False
        video_residual_2D=False
        if plot_shapes_projections or plot_shapes_slices:
            plot_Shapes2D=True
        
        if video_residual==True:
            video_residual_2D=True
    
        plot_shapes_projections=False
        plot_shapes_slices=False
        plot_activityCorrs=False
        plot_clustered_shape=False
        plot_residual_slices=False
        plot_residual_projections=False
        video_shapes=False
        video_residual=False
        video_slices=False
        print('2D data, ignoring 3D plots/video options')
    else:
        plot_Shapes2D=False
        video_residual_2D=False

    
    min_dim=np.argmin(dims[1:])
    denoised_data=0    
    detrended_data=data
    
    for rep in range(len(SaveNames)): 
        resultsName=SaveNames[rep]
        try:
            results=load('NMF_Results/'+SaveNames[rep])
        except IOError:
            if rep==0:
                print('results file not found!!')              
            else:
                break            
        SS=results['shapes']
        AA=results['activity']

        if rep>=params.Background_num:
            adaptBias=False
        else:
            adaptBias=True
            
        if IncludeBackground==True:
            adaptBias=False        
               
        L=len(AA)-adaptBias
        if L==0: #Stop if we encounter a file with zero components
            break
        S=SS[:-adaptBias]
        b=SS[L:(L+adaptBias)]
        A=AA[:-adaptBias]
        f=AA[L:(L+adaptBias)]
        if rep==0:
            shapes=S
            activity=A
            background_shapes=b
            background_activity=f
        else:
            shapes=np.append(shapes,S,axis=0)
            activity=np.append(activity,A,axis=0) 
            background_shapes=np.append(background_shapes,b,axis=0)
            background_activity=np.append(background_activity,f,axis=0) 
        
    L=len(shapes)
    adaptBias=0
    
    if Split==True:
        shapes,activity,L,all_local_max=SplitComponents(shapes,activity,adaptBias)   
    
    if Merge==True:
        shapes,activity,L=MergeComponents(shapes,activity,L,threshold=0.7,sig=10)
        
    if Prune==True:
#           deleted_indices=[5,9,11,14,15,17,24]+range(25,36)
        shapes,activity,L=PruneComponents(shapes,activity,L,params.TargetAreaRatio)
    
    activity_NonNegative=np.copy(activity)
    activity_NonNegative[activity_NonNegative<0]=0
    activity_noisy=np.copy(activity_NonNegative)
    if FineTune==True:
        for ll in range(L):
            activity[ll], spikes, baseline, g, lam = deconvolve(activity_NonNegative[ll],optimize_g=10,penalty=0)
#            activity,background_activity,S,bl,c1,sn,g,junk = update_temporal_components(data.reshape((len(data),-1)).transpose(), shapes.reshape((len(shapes),-1)).transpose(), background_shapes.reshape((len(background_shapes),-1)).transpose(), activity,background_activity,**options['temporal_params'])
        activity_noisy=np.copy(activity_NonNegative)
        activity_NonNegative=activity
    
    print(str(L)+' shapes detected')
    
    detrended_data= detrended_data - background_activity.T.dot(background_shapes.reshape((len(background_shapes), -1))).reshape(dims)        

    if Threshold==True:            
        shapes=ThresholdShapes(shapes,adaptBias,[],MaxRatio=[])
                
    if plot_residual_projections or video_shapes or video_residual or video_slices or video_residual_2D:
        colors=GetRandColors(L)
        color_shapes=np.transpose(shapes.reshape(L, -1,1)*colors,[1,0,2]) #weird transpose for tensor dot product next line
        denoised_data = denoised_data + (activity_NonNegative.T.dot(color_shapes)).reshape(tuple(dims)+(3,))   
        residual = detrended_data - activity_NonNegative.T.dot(shapes.reshape(L, -1)).reshape(dims)
    
    if detrend==True:
        data=detrended_data
 
#%% After loading loop - Normalize (colored) denoised data
    if plot_residual_projections or video_shapes or video_residual or video_slices or video_residual_2D:
#       denoised_data=denoised_data/np.max(denoised_data)
       denoised_data=old_div(denoised_data,np.percentile(denoised_data[denoised_data>0],99.5))  #%% normalize denoised data range
       denoised_data[denoised_data>1]=1           
    
    #    plt.close('all')
        
    #%% plotting params
    ComponentsInFig=20

    left  = 0.05 # the left side of the subplots of the figure
    right = 0.95   # the right side of the subplots of the figure
    bottom = 0.05   # the bottom of the subplots of the figure
    top = 0.95      # the top of the subplots of the figure
    wspace = 0.1   # the amount of width reserved for blank space between subplots
    hspace = 0.12  # the amount of height reserved for white space between subplots        
    
              
    #%% ###### Plot Individual neurons' activities
    index=0 #component display index
    sz=np.min([ComponentsInFig,L+adaptBias])
    
#    a=ceil(sqrt(sz))  
#    b=ceil(sz/a)  
    
    a=sz
    b=1
    
    if plot_activities:
        pp = PdfPages(Results_folder + 'Activities'+resultsName+'.pdf')        
        for ii in range(L+adaptBias):
            if index==0:
#                fig0=plt.figure(figsize=(dims[1] , dims[2]))
                 fig0=plt.figure(figsize=(11,18))
            ax = plt.subplot(a,b,index+1)
#            dt=1/30 # 30 Hz sample rate
            time=list(range(len(activity[ii])))
            plt.plot(time,activity_noisy[ii],linewidth=0.5,c='r')
            plt.plot(time,activity[ii],linewidth=3,c='b')
            ma=np.max([np.max(activity[ii]),np.max(activity_noisy[ii])])            
            plt.setp(ax, xticks=[],yticks=[0,ma])
            # component number
            ax.text(0.02, 0.8, str(ii),
                verticalalignment='bottom', horizontalalignment='left',
                transform=ax.transAxes,
                color='black',weight='bold', fontsize=13)
            index+=1   
            if ((ii%ComponentsInFig)==(ComponentsInFig-1)) or ii==(L+adaptBias-1):                 
                index=0
                if save_plot==True:
                    plt.subplots_adjust(left*2, bottom, right, top, wspace, hspace*2)
                    pp.savefig(fig0)    
        pp.close()
        if close_figs:
            plt.close('all')
            
    #%% Plot activities` PSDs
    index=0 #component display index
    sz=np.min([ComponentsInFig,L+adaptBias])
    a=ceil(sqrt(sz))  
    b=ceil(old_div(sz,a))  
    
    if plot_activities_PSD:
        pp = PdfPages(Results_folder + 'ActivityPSDs'+resultsName+'.pdf')        
        for ii in range(L+adaptBias):
            if index==0:
#                fig0=plt.figure(figsize=(dims[1] , dims[2]))
                 fig0=plt.figure(figsize=(11,18))
            ax = plt.subplot(a,b,index+1)
            ff, psd_activity = welch(activity[ii], nperseg=round(old_div(len(activity[ii]), 64)))
            plt.plot(ff,psd_activity,linewidth=3)
            plt.setp(ax, xticks=[],yticks=[0])
            # component number
            ax.text(0.02, 0.8, str(ii),
                verticalalignment='bottom', horizontalalignment='left',
                transform=ax.transAxes,
                color='black',weight='bold', fontsize=13)
            index+=1   
            if ((ii%ComponentsInFig)==(ComponentsInFig-1)) or ii==(L+adaptBias-1):                 
                index=0
                if save_plot==True:
                    plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
                    pp.savefig(fig0)    
        pp.close()
        if close_figs:
            plt.close('all')
            

            

    #%%  2D shapes
    index=0 #component display index
    sz=np.min([ComponentsInFig,L+adaptBias])
    a=ceil(0.5*sqrt(sz))  
    b=ceil(old_div(sz,a))  
    
    if plot_Shapes2D:            
        if save_plot==True:
            pp = PdfPages(Results_folder + 'Shapes2D_'+resultsName+'.pdf')
        for ll in range(L+adaptBias):
            if index==0:
                fig=plt.figure(figsize=(18 , 11))
            ax = plt.subplot(a,b,index+1)  
            temp=shapes[ll]
            mi=0
            try:
                ma=np.percentile(temp[temp>0],satuartion_percentile)
            except IndexError:
                ma=0
            im=plt.imshow(temp,vmin=mi,vmax=ma,cmap=color_map)
            plt.setp(ax,xticks=[],yticks=[])
            mn=int(np.floor(mi))        # colorbar min value
            mx=int(np.ceil(ma))         # colorbar max value
            md=old_div((mx-mn),2)
#                divider = make_axes_locatable(ax)
#                cax = divider.append_axes("right", size="5%", pad=0.05)
#                cb=plt.colorbar(im,cax=cax)
#                    cb.set_ticks([mn,md,mx])
#                    cb.set_ticklabels([mn,md,mx])
            
            # component number
            ax.text(0.02, 0.8, str(ll),
            verticalalignment='bottom', horizontalalignment='left',
            transform=ax.transAxes,
            color='white',weight='bold', fontsize=13)
            #sparsity
            spar_str=str(np.round(np.mean(shapes[ll]>0)*100,2))+'%'
            ax.text(0.02, 0.02, spar_str,
            verticalalignment='bottom', horizontalalignment='left',
            transform=ax.transAxes,
            color='white',weight='bold', fontsize=13)
            #L^p
            for p in range(2,6,2):
                Lp=old_div((np.sum(shapes[ll]**p))**(old_div(1,float(p))),np.sum(shapes[ll]))
                Lp_str=str(np.round(Lp*100,2))+'%' #'L'+str(p)+'='+
                ax.text(0.02+p*0.2, 0.02, Lp_str,
                verticalalignment='bottom', horizontalalignment='left',
                transform=ax.transAxes,
                color='yellow',weight='bold', fontsize=13)         
            index+=1
            if (ll%ComponentsInFig==(ComponentsInFig-1)) or ll==L+adaptBias-1: 
                plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
                index=0
                if save_plot==True:
                    pp.savefig(fig)            
        pp.close()
        if close_figs:
            plt.close('all')
            
    #%% Re-write plot code from here, so that each figure has only ComponentsInFig components         
    #%% ###### Plot Individual neurons' area which is correlated with their activities
    a=ceil(sqrt(L+adaptBias))
    b=ceil(old_div((L+adaptBias),a))
    
    if plot_activityCorrs:
        if save_plot==True:
            pp = PdfPages(Results_folder + 'CorrelationWithActivity'+resultsName+'.pdf')
        for dd in range(len(shapes[0].shape)):
            fig0=plt.figure(figsize=(11,18))
    
            for ii in range(L+adaptBias):
                ax = plt.subplot(a,b,ii+1)
                corr_imag=old_div(np.dot(activity[ii],np.transpose(data,[1,2,0,3])),np.sqrt(np.sum(data**2,axis=0)*np.sum(activity[ii]**2)))
                plt.imshow(np.abs(corr_imag).max(dd),cmap=color_map)
                plt.setp(ax,xticks=[],yticks=[])
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        
            if save_plot==True:
                pp.savefig(fig0)
        pp.close()
        if close_figs:
            plt.close('all')

    #%%  All Shapes projections
    a=ceil(sqrt(L+adaptBias))
    b=ceil(old_div((L+adaptBias),a))

    if plot_shapes_projections:
        if save_plot==True:
            pp = PdfPages(Results_folder + 'Shapes_projections'+resultsName+'.pdf')

        for dd in range(len(shapes[0].shape)):
            fig=plt.figure(figsize=(18 , 11))
            for ll in range(L+adaptBias):
                ax = plt.subplot(a,b,ll+1)  
                temp=shapes[ll].max(dd)
                if dd==2:
                    temp=temp.T
                mi=np.min(shapes[ll])
                ma=np.max(shapes[ll])
                im=plt.imshow(temp,vmin=mi,vmax=ma,cmap=color_map)
                plt.setp(ax,xticks=[],yticks=[])
                mn=int(np.floor(mi))        # colorbar min value
                mx=int(np.ceil(ma))         # colorbar max value
                md=old_div((mx-mn),2)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                cb=plt.colorbar(im,cax=cax)
#                    cb.set_ticks([mn,md,mx])
#                    cb.set_ticklabels([mn,md,mx])
                
                # component number
                ax.text(0.02, 0.8, str(ll),
                verticalalignment='bottom', horizontalalignment='left',
                transform=ax.transAxes,
                color='white',weight='bold', fontsize=13)
#                    #sparsity
                spar_str=str(np.round(np.mean(shapes[ll]>0)*100,2))+'%'
                ax.text(0.02, 0.02, spar_str,
                verticalalignment='bottom', horizontalalignment='left',
                transform=ax.transAxes,
                color='white',weight='bold', fontsize=13)
#                    #L^p
#                    for p in range(2,2,2):
#                        Lp=(np.sum(shapes[ll]**p))**(1/float(p))/np.sum(shapes[ll])
#                        Lp_str=str(np.round(Lp*100,2))+'%' #'L'+str(p)+'='+
#                        ax.text(0.02+p*0.2, 0.02, Lp_str,
#                        verticalalignment='bottom', horizontalalignment='left',
#                        transform=ax.transAxes,
#                        color='yellow',weight='bold', fontsize=13)
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            if save_plot==True:
                pp.savefig(fig)            
        pp.close()
        if close_figs:
            plt.close('all')
    #for ll in range(L+adaptBias):
    #    print 'Sparsity=',np.mean(shapes[ll]>0)
            
   
   
   #%%  All Shapes slices        
    transpose_shape= True # should we transpose shape
    ComponentsInFig=3 # number of components in Figure
    index=0 #component display index
#        z_slices=[0,1,2,3,4,5,6,7,8] #which z slices to look at slice plots/videos
    z_slices=list(range(dims[min_dim+1])) #which z slices to look at slice plots/videos
    
    if plot_shapes_slices:            
        if save_plot==True:
            pp = PdfPages(Results_folder + 'Shapes_slices'+resultsName+'.pdf')
        for ll in range(L+adaptBias):
            if index==0:
                fig=plt.figure(figsize=(18, 11))
            for dd in range(len(z_slices)):                
                ax = plt.subplot(ComponentsInFig,len(z_slices),index*len(z_slices)+dd+1) 
                temp=shapes[ll].take(dd,axis=min_dim)
                if transpose_shape:
                    temp=np.transpose(temp)                                           
                    
                mi=np.min(shapes[ll])
                ma=np.max(shapes[ll])
                im=plt.imshow(temp,vmin=mi,vmax=ma,cmap=color_map)
                plt.setp(ax,xticks=[],yticks=[])
                
                if dd==0:
                    # component number
                    ax.text(0.02, 0.8, str(ll),
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='white',weight='bold', fontsize=13)
                    #sparsity
                    spar_str=str(np.round(np.mean(shapes[ll]>0)*100,2))+'%'
                    ax.text(0.02, 0.02, spar_str,
                    verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='white',weight='bold', fontsize=13)
                    mn=int(np.floor(mi))        # colorbar min value
                    mx=int(np.ceil(ma))         # colorbar max value
                    md=old_div((mx-mn),2)
                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes("bottom", size="5%", pad=0.05)
                    cb=plt.colorbar(im,cax=cax,orientation="horizontal")
                    cb.set_ticks([mn,md,mx])
                    cb.set_ticklabels([mn,md,mx])
                    #L^p
                    for p in range(2,2,2):
                        Lp=old_div((np.sum(shapes[ll]**p))**(old_div(1,float(p))),np.sum(shapes[ll]))
                        Lp_str=str(np.round(Lp*100,2))+'%' #'L'+str(p)+'='+
                        ax.text(0.02+p*0.15, 0.02, Lp_str,
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color='yellow',weight='bold', fontsize=13)
                        
                    
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            index+=1
            if (ll%ComponentsInFig==(ComponentsInFig-1)) or ll==L+adaptBias-1:                    
                if save_plot==True:
                    pp.savefig(fig)    
                index=0
        pp.close()
        if close_figs:
            plt.close('all')
    #for ll in range(L+adaptBias):
    #    print 'Sparsity=',np.mean(shapes[ll]>0)
            
    #%% ###### Plot Individual neurons' shape projection with clustering
    a=ceil(sqrt(L+adaptBias))
    b=ceil(old_div((L+adaptBias),a))
    
    if plot_clustered_shape:
        from sklearn.cluster import spectral_clustering
        pp = PdfPages(Results_folder + 'ClusteredShapes'+resultsName+'.pdf')
        figs=[]
        for dd in range(len(shapes[0].shape)):
            figs.append(plt.figure(figsize=(18 , 11)))
        for ll in range(L):              
            ind=np.reshape(shapes[ll],(1,)+tuple(dims[1:]))>0
            temp=data[np.repeat(ind,dims[0],axis=0)].reshape(dims[0],-1)
            delta=1 #affinity trasnformation parameter
            clust=3 #number of cluster
            similarity=np.exp(old_div(-np.corrcoef(temp.T),delta))                    
            labels = spectral_clustering(similarity, n_clusters=clust, eigen_solver='arpack')
            ind2=np.array(np.nonzero(ind.reshape(-1))).reshape(-1)
            temp_shape=np.repeat(np.zeros_like(shapes[ll]).reshape(-1,1),clust,axis=1)
            for cc in range(clust):
                temp_shape[ind2[labels==cc],cc]=1
            temp_shape=temp_shape.reshape(tuple(dims[1:])+(clust,))

            for dd in range(len(shapes[0].shape)):
                current_fig=figs[dd]
                ax = current_fig.add_subplot(a,b,ll+1)
                if dd==2:
                    temp_shape=np.transpose(temp_shape,axes=[1,0,2,3])
                ax.imshow(temp_shape.max(dd))

                plt.setp(ax,xticks=[],yticks=[])
                plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        
        if save_plot==True:
            for dd in range(len(shapes[0].shape)):
                current_fig=figs[dd]
                pp.savefig(current_fig)
        pp.close()
        if close_figs:
            plt.close('all')
            
            

    #%% #####  Video Shapes
    if video_shapes:
        components=list(range(min(asarray([C,L]))))
        C=len(components)
        if restrict_support==True:
            shape_support=shapes[components[0]]>0            
            for cc in range(C):
                shape_support=np.logical_or(shape_support,shapes[components[cc]]>0)
            detrended_data=shape_support.reshape((1,)+tuple(dims[1:]))*detrended_data
        
        fig = plt.figure(figsize=(16,7))
        mi = 0
        ma = max(data)*scale
        #mi2 = 0
        #ma2 = max(shapes[ll])*max(activity[ll])
        
        ii=0
        #import colormaps as cmaps
        #cmap=cmaps.viridis
        cmap=color_map
        a=3
        b=1+C
        
        ax1 = plt.subplot(a,b,1)
        im1 = ax1.imshow(data[ii].max(0), vmin=mi, vmax=ma,cmap=cmap)
        title=ax1.set_title('Data')
        #plt.colorbar(im1)
        ax2=[] 
        ax4=[] 
        ax6=[]
        im2=[]
        im4=[]
        im6=[]
        
        for cc in range(C):
            ax2.append(plt.subplot(a,b,2+cc))
            comp=shapes[components[cc]].max(0)*activity_NonNegative[components[cc],ii]
            ma2=max(shapes[components[cc]].max(0))*max(activity_NonNegative[components[cc]])*scale
            im2.append(ax2[cc].imshow(comp,vmin=0,vmax=ma2,cmap=cmap))
        #ax2[0].set_title('Shape')
        #    plt.colorbar(im2)
        
        ax3 = plt.subplot(a,b,1+b)
        im3 = ax3.imshow(data[ii].max(1), vmin=mi, vmax=ma,cmap=cmap)
        
        #plt.colorbar(im3)
        
        for cc in range(C):
            ax4.append(plt.subplot(a,b,2+b+cc))
            comp=shapes[components[cc]].max(1)*activity_NonNegative[components[cc],ii]
            ma2=max(shapes[components[cc]].max(1))*max(activity_NonNegative[components[cc]])*scale
            im4.append(ax4[cc].imshow(comp,vmin=0,vmax=ma2,cmap=cmap))
        
        #plt.colorbar(im4)
        
        ax5 = plt.subplot(a,b,1+2*b)
        im5 = ax5.imshow(np.transpose(detrended_data[ii].max(2)), vmin=mi, vmax=ma,cmap=cmap)
        
        #plt.colorbar(im5)
        for cc in range(C):
            ax6.append(plt.subplot(a,b,2+2*b+cc))
            comp=np.transpose(shapes[components[cc]].max(2))*activity_NonNegative[components[cc],ii]
            ma2=max(shapes[components[cc]].max(2))*max(activity_NonNegative[components[cc]])*scale
            im6.append(ax6[cc].imshow(comp,vmin=0,vmax=ma2,cmap=cmap))
        
        #plt.colorbar(im6)
        
        fig.tight_layout()
        ComponentsActive=np.array([])
        for cc in range(C):
            ComponentsActive=np.append(ComponentsActive,np.nonzero(activity_NonNegative[components[cc]]))
        ComponentsActive=np.unique(ComponentsActive)
        
        def update(tt):
            ii=ComponentsActive[tt]
            im1.set_data(data[ii].max(0))        
            im3.set_data(data[ii].max(1))        
            im5.set_data(np.transpose(data[ii].max(2)))
        
            for cc in range(C): 
                im2[cc].set_data(shapes[components[cc]].max(0)*activity_NonNegative[components[cc],ii])
                im4[cc].set_data(shapes[components[cc]].max(1)*activity_NonNegative[components[cc],ii])
                im6[cc].set_data(np.transpose(shapes[components[cc]].max(2))*activity_NonNegative[components[cc],ii])
            title.set_text('Data, time = %.1f' % ii)
        
        if save_video==True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(ComponentsActive), blit=True, repeat=False)
            if restrict_support==True:
                ani.save(Results_folder + 'Shapes_Restricted'+resultsName+'.mp4',dpi=dpi,writer=writer)
            else:                        
                ani.save(Results_folder + 'Shapes_'+resultsName+'.mp4',dpi=dpi,writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(ComponentsActive), blit=True, repeat=False)
            plt.show()

    
    #%% ##### Plot denoised projection - Results

    if plot_residual_projections==True:
        
        dims=data.shape
        cmap=color_map         
        
        pic_residual=percentile(residual, 95, axis=0)
        pic_denoised = max_intensity(denoised_data, axis=0)
        pic_data=percentile(data, 95, axis=0)
        
        left  = 0.05 # the left side of the subplots of the figure
        right = 0.95   # the right side of the subplots of the figure
        bottom = 0.05   # the bottom of the subplots of the figure
        top = 0.95      # the top of the subplots of the figure
        wspace = 0.05   # the amount of width reserved for blank space between subplots
        hspace = 0.05  # the amount of height reserved for white space between subplots
        
        
        fig1=plt.figure(figsize=(11,18))
        mi=min(pic_data)
        ma=max(pic_data)
        ax = plt.subplot(311)
        im=ax.imshow(pic_data.max(0),vmin=mi,vmax=ma,cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax,xticks=[],yticks=[])
        ax2 = plt.subplot(312)
        im2=ax2.imshow(max_intensity(pic_denoised,0),interpolation='None')
        ax2.set_title('Denoised')
        plt.setp(ax,xticks=[],yticks=[])
        plt.colorbar(im2)
        ax3 = plt.subplot(313)
        im3=ax3.imshow(pic_residual.max(0),cmap=cmap)
        ax3.set_title('Residual')
        plt.setp(ax,xticks=[],yticks=[])
        plt.colorbar(im3)
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        
        fig2=plt.figure(figsize=(11,18))
        mi=min(pic_data)
        ma=max(pic_data)
        ax = plt.subplot(311)
        im=ax.imshow(pic_data.max(1),vmin=mi,vmax=ma,cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax,xticks=[],yticks=[])
        ax2 = plt.subplot(312)
        im2=ax2.imshow(max_intensity(pic_denoised,1),interpolation='None')
        ax2.set_title('Denoised')
        plt.colorbar(im2)
        plt.setp(ax,xticks=[],yticks=[])
        ax3 = plt.subplot(313)
        im3=ax3.imshow(pic_residual.max(1),cmap=cmap)
        ax3.set_title('Residual')
        plt.colorbar(im3)
        plt.setp(ax,xticks=[],yticks=[])
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
        
        fig3=plt.figure(figsize=(11,18))
        mi=min(pic_data)
        ma=max(pic_data)
        ax = plt.subplot(311)
        im=ax.imshow(pic_data.max(2).T,vmin=mi,vmax=ma,cmap=cmap)
        ax.set_title('Data')
        plt.colorbar(im)
        plt.setp(ax,xticks=[],yticks=[])
        ax2 = plt.subplot(312)
        im2=ax2.imshow(np.transpose(max_intensity(pic_denoised,2),[1,0,2]),interpolation='None')
        ax2.set_title('denoised')
        plt.setp(ax,xticks=[],yticks=[])
        plt.colorbar(im2)
        ax3 = plt.subplot(313)
        im3=ax3.imshow(np.transpose(pic_residual.max(2)),cmap=cmap)
        ax3.set_title('Residual')
        plt.colorbar(im3)
        plt.setp(ax,xticks=[],yticks=[])
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
    
        if save_plot==True:
            pp = PdfPages(Results_folder + 'Data_Denoised_Residual_Projections'+resultsName+'.pdf')
            pp.savefig(fig1)
            pp.savefig(fig2)
            pp.savefig(fig3)
            pp.close()
    
    
    #fig = plt.figure()
    #plt.plot(MSE_array)
    #plt.xlabel('Iteration')
    #plt.ylabel('MSE')
    #plt.show()
    
     #%% ##### Plot denoised slices - Results
#    z_slices=[0,2,4,6,8] #which z slices to look at slice plots/videos
    z_slices=list(range(dims[min_dim+1])) #which z slices to look at slice plots/videos
    D=len(z_slices)
    if plot_residual_slices==True:
        
        dims=data.shape
        cmap=color_map         
        
        pic_residual=percentile(residual, 95, axis=0)
        pic_denoised = max_intensity(denoised_data, axis=0)
        pic_data=percentile(data, 95, axis=0)
        
        
        a=3 #number of rows
        fig1=plt.figure(figsize=(18,11))
        mi=min(pic_data)
        ma=max(pic_data)
        for kk in range(D):        
            ax2 = plt.subplot(a,D,kk+1)
            temp=np.squeeze(np.take(pic_denoised,(z_slices[kk],),axis=min_dim))
            im2=ax2.imshow(temp,interpolation='None')
            ax2.set_title('Denoised')
            plt.setp(ax2,xticks=[],yticks=[])
            plt.colorbar(im2)
            ax = plt.subplot(a,D,kk+D+1)
            temp=np.squeeze(np.take(pic_data,(z_slices[kk],),axis=min_dim))
            im=ax.imshow(temp,vmin=mi,vmax=ma,cmap=cmap)
            ax.set_title('Data')
            plt.colorbar(im)
            plt.setp(ax,xticks=[],yticks=[])
            ax3 = plt.subplot(a,D,kk+2*D+1)
            temp=np.squeeze(np.take(pic_residual,(z_slices[kk],),axis=min_dim))
            im3=ax3.imshow(temp,cmap=cmap)
            ax3.set_title('Residual')
            plt.setp(ax3,xticks=[],yticks=[])
            plt.colorbar(im3)
            plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            
        
            if save_plot==True:
                pp = PdfPages(Results_folder + 'Data_Denoised_Residual_Slice_'+resultsName+'.pdf')
                pp.savefig(fig1)
                pp.close()
        
        
    #fig = plt.figure()
    #plt.plot(MSE_array)
    #plt.xlabel('Iteration')
    #plt.ylabel('MSE')
    #plt.show()


 
    #%% #####  2D Video Residual    
    if video_residual_2D:
        fig = plt.figure(figsize=(16,7))
        mi = 0
        ma = np.percentile(data,satuartion_percentile)
        mi3 = 0
        ma3 = old_div(ma,np.max([np.floor(old_div(ma,np.percentile(residual[residual>0],satuartion_percentile))),1]))

        ii=0
        #import colormaps as cmaps
        #cmap=cmaps.viridis
        cmap=color_map        
        
        a=1
        b=3
        
        im_array=[]
        temp=np.shape(data[ii])                   

        ax1 = plt.subplot(a,b,1)            

        pic=denoised_data[ii]
        im_array += [ax1.imshow(pic,interpolation='None')]
        ax1.set_title('Denoised')
        plt.setp(ax1,xticks=[],yticks=[])
        
        ax2 = plt.subplot(a,b,2)
        pic=data[ii]
            
        im_array += [ax2.imshow(pic, vmin=mi, vmax=ma,cmap=cmap)]
        title=ax2.set_title('Data')  
        plt.setp(ax2,xticks=[],yticks=[])
        divider = make_axes_locatable(ax2)
        cax2 = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im_array[-1], cax=cax2)          

        
        
        ax3 = plt.subplot(a,b,3)            
        pic=residual[ii]   
        im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3,cmap=cmap)]
        ax3.set_title('Residual x' + '%.1f' % (old_div(ma,ma3)))
        plt.setp(ax3,xticks=[],yticks=[])
        divider = make_axes_locatable(ax3)
        cax3 = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im_array[-1], cax=cax3)
        

#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            
        def update(ii):
            im_array[0].set_data(denoised_data[ii])
            im_array[1].set_data(data[ii])        
            im_array[2].set_data(residual[ii])                     
            
            if frame_rate!=[]:
                title.set_text('Data, time = %.2f sec' % (old_div(ii,frame_rate)))
            else:
                title.set_text('Data, time = %.1f' % ii)
        
        if save_video==True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_2D_' +resultsName+'.avi',dpi=dpi,writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()  
   
    #%% #####  Video Projections Residual    
    if video_residual:
        fig = plt.figure(figsize=(16,7))
        mi = 0
        ma = max(data)*scale
        mi3 = 0
        ma3 = max(residual)*scale

        ii=0
        #import colormaps as cmaps
        #cmap=cmaps.viridis
        cmap=color_map
        
        spatial_dims_ind=list(range(len(dims)-1))
        D=len(spatial_dims_ind)
        a=D
        b=3
        
        im_array=[]
        transpose_flags=[]
        for kk in range(D):
            transpose_flags+= [False]
            temp=np.shape(data[ii].max(spatial_dims_ind[kk]))
            if temp[0]>temp[1]:
                transpose_flags[kk]=True    
                
        for kk in range(D):
            ax1 = plt.subplot(a,b,D*kk+1)            
            if transpose_flags[kk]==False:
                pic=max_intensity(denoised_data[ii],spatial_dims_ind[kk])
            else:
                pic=np.transpose(max_intensity(denoised_data[ii],spatial_dims_ind[kk]),[1,0,2])  
            im_array += [ax1.imshow(pic,interpolation='None')]
            ax1.set_title('Denoised')
            plt.colorbar(im_array[-1])
            plt.setp(ax1,xticks=[],yticks=[])
            
            ax2 = plt.subplot(a,b,D*kk+2)
            if transpose_flags[kk]==False:
                pic=data[ii].max(spatial_dims_ind[kk])
            else:
                pic=np.transpose(data[ii].max(spatial_dims_ind[kk]))
                
            im_array += [ax2.imshow(pic, vmin=mi, vmax=ma,cmap=cmap)]
            title=ax2.set_title('Data')            
            plt.colorbar(im_array[-1])
            plt.setp(ax2,xticks=[],yticks=[])
            
            ax3 = plt.subplot(a,b,D*kk+3)            
            if transpose_flags[kk]==False:
                pic=residual[ii].max(spatial_dims_ind[kk])
            else:
                pic=np.transpose(residual[ii].max(spatial_dims_ind[kk]))        
            im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3,cmap=cmap)]
            ax3.set_title('Residual')
            plt.colorbar(im_array[-1])
            plt.setp(ax3,xticks=[],yticks=[])

#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)
            
        def update(ii):
            for kk in range(D):
                if transpose_flags[kk]==False:
                    im_array[kk*D].set_data(max_intensity(denoised_data[ii],spatial_dims_ind[kk]))
                    im_array[kk*D+1].set_data(data[ii].max(spatial_dims_ind[kk]))        
                    im_array[kk*D+2].set_data(residual[ii].max(spatial_dims_ind[kk]))                     
                else:
                    im_array[kk*D].set_data(np.transpose(max_intensity(denoised_data[ii],spatial_dims_ind[kk]),[1,0,2]))
                    im_array[kk*D+1].set_data(np.transpose(data[ii].max(spatial_dims_ind[kk])))        
                    im_array[kk*D+2].set_data(np.transpose(residual[ii].max(spatial_dims_ind[kk])))                     
            
            title.set_text('Data, time = %.1f' % ii)
        
        if save_video==True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_Projections'+resultsName+'.mp4',dpi=dpi,writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()  
            
    #%% #####  Video Slices Residual    
#    z_slices=[0,2,4,6,8] #which z slices to look at slice plots/videos    
    z_slices=list(range(dims[min_dim+1])) #which z slices to look at slice plots/videos
    
    if video_slices:
        fig = plt.figure(figsize=(16,7))
        mi = 0
        ma = np.percentil(data[data>0],satuartion_percentile)
        mi3 = 0
        ma3 = np.percentil(data[data>0],satuartion_percentile)

        ii=0
        #import colormaps as cmaps
        #cmap=cmaps.viridis
        cmap=color_map
        a=3
        
        D=len(z_slices) #number of spatial dimensions
        im_array=[]
        transpose_flag= True
                
        for kk in range(D):
            ax1 = plt.subplot(a,D,kk+1)            
            temp=np.squeeze(np.take(denoised_data[ii],(z_slices[kk],),axis=min_dim))
            if transpose_flag==False:
                pic=temp
            else:
                pic=np.transpose(temp,[1,0,2])  
            im_array += [ax1.imshow(pic,interpolation='None')]
            ax1.set_title('Denoised, z='+ str(z_slices[kk]+1))
            plt.colorbar(im_array[-1])
            plt.setp(ax1,xticks=[],yticks=[])
            
            ax2 = plt.subplot(a,D,kk+D+1)
            temp=np.squeeze(np.take(data[ii],(z_slices[kk],),axis=min_dim))
            if transpose_flag==False:
                pic=temp
            else:
                pic=np.transpose(temp)
                
            im_array += [ax2.imshow(pic, vmin=mi, vmax=ma,cmap=cmap)]
            title=ax2.set_title('Data')            
            plt.colorbar(im_array[-1])
            plt.setp(ax2,xticks=[],yticks=[])
            
            ax3 = plt.subplot(a,D,kk+2*D+1) 
            temp=np.squeeze(np.take(residual[ii],(z_slices[kk],),axis=min_dim))
            if transpose_flag==False:
                pic=temp
            else:
                pic=np.transpose(temp)       
            im_array += [ax3.imshow(pic, vmin=mi3, vmax=ma3,cmap=cmap)]
            ax3.set_title('Residual')
            plt.colorbar(im_array[-1])
            plt.setp(ax3,xticks=[],yticks=[])

#        fig.tight_layout()
        plt.subplots_adjust(left, bottom, right, top, wspace, hspace)        
        
        def update(ii):
            for kk in range(D):
                temp1=np.squeeze(np.take(denoised_data[ii],(z_slices[kk],),axis=min_dim))
                temp2=np.squeeze(np.take(data[ii],(z_slices[kk],),axis=min_dim))
                temp3=np.squeeze(np.take(residual[ii],(z_slices[kk],),axis=min_dim))
                if transpose_flag==False:                    
                    im_array[a*kk].set_data(temp1)
                    im_array[a*kk+1].set_data(temp2)        
                    im_array[a*kk+2].set_data(temp3)                     
                else:
                    im_array[a*kk].set_data(np.transpose(temp1,[1,0,2]))
                    im_array[a*kk+1].set_data(np.transpose(temp2))        
                    im_array[a*kk+2].set_data(np.transpose(temp3))                     
            
            title.set_text('Data, time = %.1f' % ii)
        
        if save_video==True:
            writer = animation.writers['ffmpeg'](fps=10)
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            ani.save(Results_folder + 'Data_Denoised_Residual_Slices'+resultsName+'.mp4',dpi=dpi,writer=writer)
        else:
            ani = animation.FuncAnimation(fig, update, frames=len(data), blit=False, repeat=False)
            plt.show()              
def LocalNMF(data, centers, sig, NonNegative=True,FinalNonNegative=True,verbose=False,adaptBias=True,TargetAreaRatio=[],estimateNoise=False,
             PositiveError=False,MedianFilt=False,Connected=False,FixSupport=False, WaterShed=False,SmoothBkg=False,FineTune=True,Deconvolve=False,
             SigmaMask=[],updateLambdaIntervals=2,updateRhoIntervals=2,addComponentsIntervals=1,bkg_per=20,SigmaBlur=[],
             iters=10,iters0=[30], mbs=[1], ds=1,lam1_s=0,lam1_t=0,lam2_s=0,lam2_t=0):
    """
    Localized non-negative matrix factorization of an imaging block.

    The routine first fits spatial components ("shapes") and temporal
    components ("activity") on a temporally/spatially downscaled copy of the
    data (``iters0`` iterations per entry of ``mbs``), optionally adding new
    components, and then runs ``iters`` main iterations (on the full data if
    ``FineTune`` is True) while adapting the per-pixel L1 penalty ``lam1_s``
    with an exponential search.

    Parameters
    ----------
    data : array, shape (T, X, Y[, Z])
        block of the data
    centers : array, shape (L, D)
        L centers of suspected neurons where D is spatial dimension (2 or 3)
    sig : array, shape (D,)
        size of the gaussian kernel in different spatial directions
    NonNegative : boolean
        if True, neurons activity should be considered as non-negative
    FinalNonNegative : boolean
        if False, last activity iteration is done without non-negativity constraint, even if NonNegative==True
    verbose : boolean
        print progress and record MSE if true (about 2x slower)
    adaptBias : boolean
        subtract rank 1 estimate of bias (background)
    TargetAreaRatio : list of length 2 (or empty)
        Lower and upper bounds on sparsity of non-background components
    estimateNoise : boolean
        estimate noise variance and use it determine if to add components, and to modify sparsity by affecting lam1_s (does not work very well)
    PositiveError : boolean
        do not allow pixels in which the residual (summed over time) becomes negative, by increasing lam1_s in these pixels
    MedianFilt : boolean
        do median filter of spatial components
    Connected: boolean
        impose connectedness of spatial component by keeping only the largest non-zero connected component in each iteration of HALS
    WaterShed: boolean
        impose that each spatial component has a single watershed region
    SmoothBkg: boolean
        Remove local peaks from background component
    FixSupport : boolean
        do not allow spatial components to be non-zero where sub-sampled spatial components are zero
    FineTune :  boolean
        fine tune main iterations on full data, if not, use (last) downsampled data
    Deconvolve : boolean
        Deconvolve activity to get smoothed (denoised) calcium trace. This is done only on the main iterations, and if FineTune=True
    SigmaMask : scalar or empty
        if not [], then update masks so that they are SigmaMasks around non-zero support of shapes
    SigmaBlur : scalar
        if not [], then de-blur spatial components using Gaussian Kernel of this width
    updateLambdaIntervals : int
        update lam1_s every this number of HALS iterations, to match constraints
    updateRhoIntervals : int
        decrease rho, update rate of lam1_s, every this number of updateLambdaIntervals HALS iterations (only active during main iterations)
    addComponentsIntervals : int
        add new component, if possible, every this number of updateLambdaIntervals HALS iterations (only active during sub-sampled iterations)
    bkg_per : float
        the background is initialized at this height (percentile image)
    iters : int
        number of final iterations on whole data
    iters0 : list
        numbers of initial iterations on subset
    mbs : list
        minibatchsizes for temporal downsampling
    ds : int or list
        factor for spatial downsampling, can be an integer or a list of the size of spatial dimensions
    lam1_s : float
        L_1 regularization constant for sparsity of shapes
    lam2_s : float
        L_2 regularization constant for sparsity of shapes
    lam1_t : float
        L_1 regularization constant for sparsity of activity
    lam2_t : float
        L_2 regularization constant for sparsity of activity

    Returns
    -------
    MSE_array : array (empty if verbose is False)
        Mean square error recorded during the main iterations
    shapes : array, shape (K, X, Y (,Z))
        the spatial components left after pruning/merging
        (presumably including the background component when adaptBias is
        True -- depends on PruneComponents/MergeComponents)
    activity : array, shape (K, T)
        the temporal activity for each returned shape

    If no non-background component survives the sub-sampled stage, three
    empty lists are returned instead.

    Raises
    ------
    NameError
        if ``ds != 1`` and ``SigmaBlur != []`` (combination not implemented)
    """

    # Catch Errors
    if ds!=1 and SigmaBlur!=[]:
        raise NameError('case ds!=1 and SigmaBlur!=[] no yet written in NMF code')


    # Initialize Parameters
    dims = data.shape # data dimensions
    D = len(dims) #number of data dimensions
    R = 3 * asarray(sig)  # size of bounding box is 3 times size of neuron
    L = len(centers) # number of components (not including background)
    inner_iterations=10 # number of iterations in inners loops
    shapes = [] #array of spatial components
    mask = [] # binary array, support of spatial components
    boxes = zeros((L, D - 1, 2), dtype=int) #initial support of spatial components
    MSE_array = [] #CNMF residual error
    mb = mbs[0] if iters0[0] > 0 else 1
    activity = zeros((L, old_div(dims[0], mb))) #array of temporal components
    lam1_s0=np.copy(lam1_s) #initial spatial sparsity (l1) parameters
    if TargetAreaRatio!=[]:
        if TargetAreaRatio[0]>TargetAreaRatio[1]:
            print('WARNING -  TargetAreaRatio[0]>TargetAreaRatio[1] !!!')
    if iters0[0] == 0:
        ds = 1


### Initialize shapes, activity, and residual ###

    data0,dims0=DownScale(data,mb,ds) #downscaled data and dimensions
    if isinstance(ds,int):
        ds=ds*np.ones(D-1)
    # integer copies of the downsampling factors: reshape/repeat need true
    # integers, while `ds` itself may be a float array after the line above
    ds_int=[int(d) for d in ds]

    if D == 4: #downscale activity
        activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))), list(map(int, old_div(centers[:, 1], ds[1]))),
                         list(map(int, old_div(centers[:, 2], ds[2])))].T
    else:
        activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))), list(map(int, old_div(centers[:, 1], ds[1])))].T

    data0 = data0.reshape(dims0[0], -1) #reshape data0 to more convenient time x space form
    Energy0=np.sum(data0**2,axis=0) #data0 energy per pixel
    data0sum=np.sum(data0,axis=0) # for sign check later

    data = data.astype('float').reshape(dims[0], -1) #reshape data to more convenient time x space form
    datasum=np.sum(data,axis=0)# for sign check later

    # float is faster than float32, presumable float32 gets converted later on
    # to float again and again
    Energy=np.sum((data**2),axis=0) #data energy per pixel

    # extract shapes and activity from given centers
    for ll in range(L):
        boxes[ll] = GetBox(old_div(centers[ll], ds), old_div(R, ds), dims0[1:])
        temp = zeros(dims0[1:])
        # a multi-dimensional index must be a tuple (a list of slices is rejected by current NumPy)
        temp[tuple(slice(*a) for a in boxes[ll])]=1
        mask += np.where(temp.ravel())
        temp = [old_div((arange(int(old_div(dims[i + 1], ds[i]))) -int( old_div(centers[ll][i], ds[i]))) ** 2, (2 * (old_div(sig[i], ds[i])) ** 2))
                for i in range(D - 1)]
        temp = exp(-sum(ix_(*temp)))
        temp.shape = (1,) + dims0[1:]
        temp = RegionCut(temp, boxes[ll])
        shapes.append(temp[0])
    S = zeros((L + adaptBias, prod(dims0[1:]))) #shape component
    for ll in range(L):
        S[ll] = RegionAdd(
            zeros((1,) + dims0[1:]), shapes[ll].reshape(1, -1), boxes[ll]).ravel()
    if adaptBias:
        # Initialize background as bkg_per percentile
        S[-1] = percentile(data0, bkg_per, 0)
        activity = np.r_[activity, ones((1, dims0[0]))]

    lam1_s=lam1_s0*np.ones_like(S)*mbs[0] #initialize sparsity parameters


### Get shape estimates on subset of data ###
    # NOTE(review): when iters0[0]==0 this whole branch is skipped, so ES,
    # sn_target, sn_std and MSE_target are never defined before the main loop
    # below uses them -- confirm whether iters0=[0] is meant to be supported.
    if iters0[0] > 0:
        for it in range(len(iters0)):
            if estimateNoise:
                sn_target,sn_std= GetSnPSDArray(data0)#target noise level
            else:
                sn_target=np.zeros(prod(dims0[1:]))
                sn_std=sn_target
            MSE_target = np.mean(sn_target**2)
            ES=ExponentialSearch(lam1_s) #object to update sparsity parameters
            lam1_s=ES.lam
            for kk in range(iters0[it]):
                # update sparsity parameters
                if kk%updateLambdaIntervals==0:
                    # NOTE(review): here the sqrt is divided by T, whereas the main loop
                    # takes the sqrt of (energy/T) -- confirm which one is intended.
                    sn=old_div(np.sqrt(Energy0-2*np.sum(np.dot(activity,data0)*S,axis=0)+np.sum(np.dot(np.dot(activity,activity.T),S)*S,axis=0)),dims0[0]) # efficient way to calculate MSE per pixel

                    delta_sn=sn-sn_target # noise margin
                    signcheck=(data0sum-np.dot(np.sum(activity.T,axis=0),S))<0
                    if PositiveError: #obsolete
                        delta_sn[signcheck]=-float("inf") # residual should not have negative pixels, so we increase lambda for these pixels

                    if len(S)==0:
                        spars=0
                    else:
                        spars=np.mean(S>0,axis=1)

                        temp=repeat(delta_sn.reshape(1,-1),L+adaptBias,axis=0)

                        if TargetAreaRatio==[]:
                            cond_decrease=temp>sn_std
                            cond_increase=temp<-sn_std
                        else:
                            if adaptBias:
                                spars[-1]=old_div((TargetAreaRatio[1]+TargetAreaRatio[0]),2) # ignore sparsity target for background (bias) component
                            temp2=repeat(spars.reshape(-1,1),len(S[0]),axis=1)
                            cond_increase=np.logical_or(temp2>TargetAreaRatio[1],temp<-sn_std)
                            cond_decrease=np.logical_and(temp2<TargetAreaRatio[0],temp>sn_std)

                        ES.update(cond_decrease,cond_increase)
                        lam1_s=ES.lam

                    #Print residual error and additional information
                    MSE = np.mean(sn**2)

                    if verbose and L>0:
                        print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(MSE,MSE_target,np.mean(spars[:L]),np.mean(lam1_s)))

                    #add a new component
                    if (kk%addComponentsIntervals==0) and (kk!=iters0[it]-1):

                        delta_sn[signcheck]=-float("inf") # residual should not have negative pixels
                        new_cent=np.argmax(delta_sn) #should I smooth the data a bit first?
                        MSE_std=np.mean(sn_std**2)
                        checkNoZero= not((0 in np.sum(activity,axis=1)) and (0 in np.sum(S,axis=1)))
                        if ((MSE-MSE_target>2*MSE_std) and checkNoZero and (delta_sn[new_cent]>sn_std[new_cent])):
                            S, activity, mask,centers,boxes,L=addComponent(new_cent,data0,dims0,old_div(R,ds),S, activity, mask,centers,boxes,adaptBias)
                            new_lam=lam1_s0*np.ones_like(data0[0,:]).reshape(1,-1)
                            lam1_s=np.insert(lam1_s,0,values=new_lam,axis=0)
                            ES=ExponentialSearch(lam1_s) #we need to restart exponential search each time we add a component

                #apply additional constraints/processing
                if SigmaBlur==[]:
                    S = HALS4shape(data0, S, activity,mask,lam1_s,lam2_s,adaptBias,inner_iterations)
                else: #obsolete
                    S=FISTA4shape(data0, S, activity,mask,lam1_s,adaptBias,SigmaBlur,dims0)

                if Connected==True:
                    S=LargestConnectedComponent(S,dims0,adaptBias)
                if WaterShed==True:
                    S=LargestWatershedRegion(S,dims0,adaptBias)
                activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,inner_iterations)
                if SigmaMask!=[]:
                    mask=GrowMasks(S,mask,boxes,dims0,adaptBias,SigmaMask)
                S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
                lam1_s=ES.lam
                if SmoothBkg==True:
                    S=SmoothBackground(S,dims0,adaptBias,tuple(old_div(np.array(sig),np.array(ds))))

                print('Subsampled iteration',kk,'it=',it,'L=',L)

            # use next (smaller) value for temporal downscaling
            if it < len(iters0) - 1:
                mb = mbs[it + 1]
                # floor division: the slice bound must be an integer multiple of mb
                data0 = data[:len(data) // mb * mb].reshape(-1, mb, prod(dims[1:])).mean(1)
                if D==4:
                    data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds_int[0], int(old_div(dims[2], ds[1])), ds_int[1],
                                          int(old_div(dims[3], ds[2])), ds_int[2]).mean(-1).mean(-2).mean(-3)
                else:
                    data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds_int[0], int(old_div(dims[2], ds[1])),
                                          ds_int[1]).mean(-1).mean(-2)
                data0.shape = (len(data0), -1)

                activity = ones((L + adaptBias, len(data0))) * activity.mean(1).reshape(-1, 1)
                lam1_s=lam1_s*mbs[it+1]/mbs[it]
                activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,30)
                S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
                lam1_s=ES.lam

    ### Stop adding components ###
        if L==0: #if no non-background components found, return empty results
            # same arity as the normal return below, so callers can unpack uniformly
            return [], [], []

        if FineTune: ### Upscale Back to full data ##
            activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1)
            data0=data
            # S and lam1_s still live on the downscaled grid here, so they must be
            # reshaped with the downscaled dims0 *before* dims0 is switched to dims
            if D==4:
                S = repeat(repeat(repeat(S.reshape((-1,) + dims0[1:]), ds_int[0], 1), ds_int[1], 2), ds_int[2], 3)
                lam1_s= repeat(repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds_int[0], 1), ds_int[1], 2), ds_int[2], 3)
            else:
                S = repeat(repeat(S.reshape((-1,) + dims0[1:]), ds_int[0], 1), ds_int[1], 2)
                lam1_s= repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds_int[0], 1), ds_int[1], 2)
            dims0=dims
            # pad by replicating the last slice where the upscaled grid is smaller than the data
            for dd in range(1,D):
                while S.shape[dd]<dims[dd]:
                    shape_append=np.array(S.shape)
                    shape_append[dd]=1
                    S=np.append(S,values=np.take(S,-1,axis=dd).reshape(shape_append),axis=dd)
                    lam1_s=np.append(lam1_s,values=np.take(lam1_s,-1,axis=dd).reshape(shape_append),axis=dd)
            S=S.reshape(L + adaptBias, -1)
            lam1_s=lam1_s.reshape(L+ adaptBias,-1)
            # rebuild boxes and masks at full resolution
            for ll in range(L):
                boxes[ll] = GetBox(centers[ll], R, dims[1:])
                temp = zeros(dims[1:])
                temp[tuple(slice(*a) for a in boxes[ll])] = 1
                mask[ll] = np.where(temp.ravel())[0]

            if FixSupport: #obsolete
                for ll in range(L):
                    lam1_s[ll,S[ll]==0]=float("inf")


            ES=ExponentialSearch(lam1_s)
            activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur, 30)
            S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
            lam1_s=ES.lam

            if estimateNoise:
                sn_target,sn_std= GetSnPSDArray(data0)#target noise level
            else:
                sn_target=np.zeros(prod(dims0[1:]))
                sn_std=sn_target
            MSE_target = np.mean(sn_target**2)
            MSE_std=np.mean(sn_std**2)
    #        MSE = np.mean((data0-np.dot(activity.T,S))**2)

#### Main Loop ####

    print('starting main NMF loop')
    for kk in range(iters):
        lam1_s=ES.lam #update sparsity parameters
        if SigmaBlur==[]:
            S = HALS4shape(data0, S, activity,mask,lam1_s,lam2_s,adaptBias,inner_iterations)
        else: #obsolete
            S = FISTA4shape(data0, S, activity,mask,lam1_s,adaptBias,SigmaBlur,dims0)
        #apply additional constraints/processing
        if Connected==True:
            S=LargestConnectedComponent(S,dims0,adaptBias)
        if WaterShed==True:
            S=LargestWatershedRegion(S,dims0,adaptBias)
        if kk==iters-1:
            if FinalNonNegative==False:
                NonNegative=False
        activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,inner_iterations)
        if FineTune and Deconvolve:
            for ll in range(L):
                if np.sum(np.abs(activity[ll])>0)>30: #make sure there is enough signal before we try to deconvolve
                    activity[ll], _, _, _, _ = deconvolve(activity[ll], penalty=0)

        if SigmaMask!=[]:
            mask=GrowMasks(S,mask,boxes,dims0,adaptBias,SigmaMask)
        S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)

        # Measure MSE and update sparsity parameters
        print('main iteration kk=',kk,'L=',L)
        if (kk+1)%updateLambdaIntervals==0:
            sn=np.sqrt(old_div((Energy-2*np.sum(np.dot(activity,data0)*S,axis=0)+np.sum(np.dot(np.dot(activity,activity.T),S)*S,axis=0)),dims0[0]))
            delta_sn=sn-sn_target
            MSE = np.mean(sn**2)

            signcheck=(datasum-np.dot(np.sum(activity.T,axis=0),S))<0
            if PositiveError: #obsolete
                delta_sn[signcheck]=-float("inf") # residual should not have negative pixels, so we increase lambda for these pixels

            # test emptiness by length; comparing an ndarray to [] is elementwise
            if len(S)==0:
                spars=0
            else:
                spars=np.mean(S>0,axis=1)

            temp=repeat(delta_sn.reshape(1,-1),L+adaptBias,axis=0)

            if TargetAreaRatio==[]:
                cond_decrease=temp>sn_std
                cond_increase=temp<-sn_std
            else:
                if adaptBias:
                    spars[-1]=old_div((TargetAreaRatio[1]+TargetAreaRatio[0]),2) # ignore sparsity target for background (bias) component
                temp2=repeat(spars.reshape(-1,1),len(S[0]),axis=1)
                cond_increase=np.logical_or(temp2>TargetAreaRatio[1],temp<-sn_std)
                cond_decrease=np.logical_and(temp2<TargetAreaRatio[0],temp>sn_std)


            ES.update(cond_decrease,cond_increase)
            lam1_s=ES.lam
            if kk<old_div(iters,3): #restart exponential search unless enough iterations have passed
                ES=ExponentialSearch(lam1_s)
            else:
                if not(np.any(cond_increase) or np.any(cond_decrease)):
                    print('sparsity target reached')
                    break
                if L+adaptBias>1: # if we have more than one component just keep exponentiated grad descent instead
                    if (kk+1)%updateRhoIntervals==0: #update rho every updateRhoIntervals if we are still not converged
                        # without a sparsity target there is nothing to compare against, so rho is left unchanged
                        if TargetAreaRatio!=[] and (np.any(spars[:L]<TargetAreaRatio[0]) or np.any(spars[:L]>TargetAreaRatio[1])):
                            ES.rho=2-old_div(1,(ES.rho))
                            print('rho=',ES.rho)
                    ES=ExponentialSearch(lam1_s,rho=ES.rho)

            # print MSE and other information
            if verbose:
                print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(MSE,MSE_target,np.mean(spars[:L]),np.mean(lam1_s)))
                if kk == (iters - 1):
                    print('Maximum iteration limit reached')
                MSE_array.append(MSE)

    # Some post-processing
    S=S.reshape((-1,) + dims[1:])
    S,activity,L=PruneComponents(S,activity,L) #prune "bad" components
    if len(S)>1:
        S,activity,L=MergeComponents(S,activity,L,threshold=0.9,sig=10)    #merge very similar components
        if not FineTune:
            activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1) #extract activity from full data
        activity=HALS4activity(data, S.reshape((len(S),-1)), activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,iters=30)

    return asarray(MSE_array), S, activity


# example to check code works


#T = 1000
#X = 201
#Y = 101
#data = np.random.randn(T, X, Y)
#centers = asarray([[40, 30]])
#data[:, 30:45, 25:33] += 2*np.sin(np.array(range(T))/200).reshape(-1,1,1)*np.ones([T,15,8])
#sig = [300, 300]
#
#MSE_array, shapes, activity = LocalNMF(
#    data, centers, sig, NonNegative=True, verbose=True,lam1_s=0.1,adaptBias=True)
#
#
#import matplotlib.pyplot as plt
#plt.imshow(shapes[0])
#
#for ll in range(len(shapes)):
#    print(np.mean(shapes[ll]>0))