BlockLocalNMF.py
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import int
from future import standard_library
standard_library.install_aliases()
from builtins import map
from builtins import range
from past.utils import old_div
from numpy import asarray, percentile, zeros, ones, ix_, arange, exp, prod, repeat
import numpy as np
from BlockLocalNMF_AuxilaryFunctions import HALS4activity, HALS4shape,RenormalizeDeleteSort,addComponent,GetBox, \
RegionAdd,RegionCut,DownScale,LargestConnectedComponent,LargestWatershedRegion,SmoothBackground,GetSnPSDArray,ExponentialSearch,GrowMasks,FISTA4shape
from AuxilaryFunctions import PruneComponents,MergeComponents,make_sure_path_exists
import sys
OASIS_path='OASIS/'
make_sure_path_exists(OASIS_path)
sys.path.append(OASIS_path)
from functions import deconvolve
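# 'deconvolve' is imported from functions.py, expected under OASIS_path (the OASIS
# deconvolution code); it is only used in the main iterations below, when Deconvolve=True
# and FineTune=True.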
def LocalNMF(data, centers, sig, NonNegative=True,FinalNonNegative=True,verbose=False,adaptBias=True,TargetAreaRatio=[],estimateNoise=False,
PositiveError=False,MedianFilt=False,Connected=False,FixSupport=False, WaterShed=False,SmoothBkg=False,FineTune=True,Deconvolve=False,
SigmaMask=[],updateLambdaIntervals=2,updateRhoIntervals=2,addComponentsIntervals=1,bkg_per=20,SigmaBlur=[],MergeThreshold_activity=1,MergeThreshold_shapes=1,
iters=10,iters0=[30], mbs=[1], ds=1,lam1_s=0,lam1_t=0,lam2_s=0,lam2_t=0):
"""
Parameters
----------
data : array, shape (T, X, Y[, Z])
block of the data
centers : array, shape (L, D)
L centers of suspected neurons where D is spatial dimension (2 or 3)
sig : array, shape (D,)
size of the gaussian kernel in different spatial directions
NonNegative : boolean
if True, neurons activity should be considered as non-negative
FinalNonNegative : boolean
if False, last activity iteration is done without non-negativity constraint, even if NonNegative==True
verbose : boolean
print progress and record MSE if true (about 2x slower)
adaptBias : boolean
subtract rank 1 estimate of bias (background)
TargetAreaRatio : list of length 2
Lower and upper bounds on sparsity of non-background components
estimateNoise : boolean
estimate the noise variance and use it to decide whether to add components, and to modify sparsity by adjusting lam1_s (does not work very well)
PositiveError : boolean
do not allow pixels in which the residual (summed over time) becomes negative, by increasing lam1_s in these pixels
MedianFilt : boolean
do median filter of spatial components
Connected: boolean
impose connectedness of spatial component by keeping only the largest non-zero connected component in each iteration of HALS
WaterShed: boolean
impose that each spatial component has a single watershed region
SmoothBkg: boolean
Remove local peaks from background component
FixSupport : boolean
do not allow spatial components to be non-zero where sub-sampled spatial components are zero
FineTune : boolean
fine-tune the main iterations on the full data; if False, use the (last) downsampled data
Deconvolve : boolean
Deconvolve activity to get a smoothed (denoised) calcium trace. This is done only in the main iterations, and only if FineTune=True
SigmaMask : scalar or empty
if not [], then update masks so that they are non-zero within a radius of SigmaMask around the previous non-zero support of the shapes
SigmaBlur : scalar
if not [], then de-blur spatial components using Gaussian Kernel of this width
updateLambdaIntervals : int
update lam1_s every this number of HALS iterations, to match constraints
updateRhoIntervals : int
decrease rho, the update rate of lam1_s, every this number of updateLambdaIntervals HALS iterations (only active during the main iterations)
addComponentsIntervals : int
add new component, if possible, every this number of updateLambdaIntervals HALS iterations (only active during sub-sampled iterations)
bkg_per : float in the range [0,100]
the background is initialized at this height (percentile image)
iters : int
number of final iterations on whole data
iters0 : list
numbers of initial iterations on subset
mbs : list
minibatch sizes for temporal downsampling
ds : int or list
factor for spatial downsampling, can be an integer or a list of the size of spatial dimensions
lam1_s : float
L_1 regularization constant for sparsity of shapes
lam2_s : float
L_2 regularization constant for sparsity of shapes
lam1_t : float
L_1 regularization constant for sparsity of activity
lam2_t : float
L_2 regularization constant for sparsity of activity
MergeThreshold_activity: float between 0 and 1
merge components if their activity is correlated above this threshold (and they are sufficiently close)
MergeThreshold_shapes: float between 0 and 1
merge components if their shapes are correlated above this threshold (and they are sufficiently close)
Returns
-------
MSE_array : list (empty if verbose is False)
Mean square error during algorithm operation
shapes : array, shape (L+adaptBias, X, Y (,Z))
the neuronal shape vectors (empty if no components found)
activity : array, shape (L+adaptBias, T)
the neuronal activity for each shape (empty if no components found)
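Examples
--------
A minimal synthetic-data call (illustrative values only; see also the commented
example at the end of this file)::

    import numpy as np
    data = np.random.randn(100, 64, 64)
    centers = np.asarray([[32, 32]])
    sig = [3, 3]
    MSE_array, shapes, activity = LocalNMF(data, centers, sig, verbose=True)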
"""
# Catch Errors
if ds!=1 and SigmaBlur!=[]:
raise NameError('case ds!=1 and SigmaBlur!=[] not yet written in NMF code')
# Initialize Parameters
dims = data.shape # data dimensions
D = len(dims) #number of data dimensions
R = (3 * asarray(sig)).astype('uint8') # size of bounding box is 3 times size of neuron
L = len(centers) # number of components (not including background)
inner_iterations=10 # number of iterations in inner loops
shapes = [] #array of spatial components
mask = [] # binary array, support of spatial components
boxes = zeros((L, D - 1, 2), dtype=int) #initial support of spatial components
MSE_array = [] #CNMF residual error
mb = mbs[0] if iters0[0] > 0 else 1
activity = zeros((L, old_div(dims[0], mb))) #array of temporal components
lam1_s0=np.copy(lam1_s) #initial spatial sparsity (l1) parameters
if TargetAreaRatio!=[]:
if TargetAreaRatio[0]>TargetAreaRatio[1]:
print('WARNING - TargetAreaRatio[0]>TargetAreaRatio[1] !!!')
if iters0[0] == 0:
ds = 1
### Initialize shapes, activity, and residual ###
data0,dims0=DownScale(data,mb,ds) #downscaled data and dimensions
if isinstance(ds,int):
ds=(ds*np.ones(D-1)).astype('uint8')
if D == 4: # initialize activity from the downscaled data at the given centers (3D data)
activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))), list(map(int, old_div(centers[:, 1], ds[1]))),
list(map(int, old_div(centers[:, 2], ds[2])))].T
else:
activity = data0[:, list(map(int, old_div(centers[:, 0], ds[0]))), list(map(int, old_div(centers[:, 1], ds[1])))].T
data0 = data0.reshape(dims0[0], -1) #reshape data0 to more convenient time x space form
comp_method=None
M=30 #compression ratio
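# Optional compression of data0 before the sub-sampled iterations (off by default since
# comp_method is hard-coded to None): 'subsample' keeps M evenly spaced frames, 'random'
# uses structured random projections (compressed NMF, Tepper & Sapiro), and 'svd'
# projects the frames onto the top M eigenvectors of the temporal covariance.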
if comp_method == 'subsample':
data0 = data0[np.linspace(0, dims0[0] - 1, M).astype('int')]
dims0 = (M,) + dims0[1:]
elif comp_method == 'random':
np.random.seed(1)
# Mariano Tepper, Guillermo Sapiro: COMPRESSED NONNEGATIVE MATRIX
# FACTORIZATION IS FAST AND ACCURATE
Om = np.random.randn(np.prod(dims0[1:]), M).astype('float32')
# B = data0.dot(data0.T.dot(data0.dot(Om)))
B = data0.dot(Om)
Lmatrix = np.linalg.qr(B)[0]
Om = np.random.randn(dims0[0], M).astype('float32')
# B = data0.T.dot(data0.dot(data0.T.dot(Om)))
B = data0.T.dot(Om)
Rmatrix = np.linalg.qr(B)[0].T
dataL = Lmatrix.T.dot(data0)
dataR = data0.dot(Rmatrix.T)
#check non-negativity
print(np.min(dataL))
print(np.min(dataR))
elif comp_method == 'svd':
if mb > 1:
data_dec = data0.copy()
COV = data0.dot(data0.T)
_, V = np.linalg.eigh(COV) # eigenvalues are returned in ascending order
V = V[:, -M:] # keep the eigenvectors of the M largest eigenvalues
data0 = V.T.dot(data0)
dims0 = (M,) + dims0[1:]
# print(np.min(data0))
if comp_method is not None:
if D == 4: # re-initialize activity from the compressed data at the given centers (3D data)
activity = data0.reshape(dims0)[:, centers[:, 0].astype('int'), centers[:, 1].astype('int'),centers[:, 2].astype('int')].T
else:
activity = data0.reshape(dims0)[:, centers[:, 0].astype('int'), centers[:, 1].astype('int')].T
Energy0=np.sum(data0**2,axis=0) #data0 energy per pixel
data0sum=np.sum(data0,axis=0) # for sign check later
data = data.astype('float').reshape(dims[0], -1) #reshape data to more convenient time x space form
datasum=np.sum(data,axis=0)# for sign check later
# float is faster than float32, presumably float32 gets converted later on
# to float again and again
Energy=np.sum((data**2),axis=0) #data energy per pixel
# extract shapes and activity from given centers
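# Each shape is initialized as a Gaussian bump of width sig (in downscaled coordinates)
# centered on its given center, cut to a bounding box of half-width R = 3*sig (before
# downscaling); 'mask' keeps the linear pixel indices of that box as the component's
# allowed support.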
for ll in range(L):
boxes[ll] = GetBox(old_div(centers[ll], ds), old_div(R, ds), dims0[1:])
temp = zeros(dims0[1:])
temp[[slice(*a) for a in boxes[ll]]]=1
mask += np.where(temp.ravel())
temp = [old_div((arange(int(old_div(dims[i + 1], ds[i]))) -int( old_div(centers[ll][i], ds[i]))) ** 2, (2 * (old_div(sig[i], ds[i])) ** 2))
for i in range(D - 1)]
temp = exp(-sum(ix_(*temp)))
temp.shape = (1,) + dims0[1:]
temp = RegionCut(temp, boxes[ll])
shapes.append(temp[0])
S = zeros((L + adaptBias, prod(dims0[1:]))) #shape component
for ll in range(L):
S[ll] = RegionAdd(
zeros((1,) + dims0[1:]), shapes[ll].reshape(1, -1), boxes[ll]).ravel()
if adaptBias:
# Initialize background as bkg_per percentile
if comp_method == 'svd':
activity = np.r_[activity, V.sum(0).reshape(1, -1)]
S[-1] = np.percentile(data_dec, bkg_per, 0) if mb > 1 else np.percentile(data, bkg_per, 0)
else:
activity = np.r_[activity, ones((1, dims0[0]), dtype='float32')]
S[-1] = np.percentile(data0, bkg_per, 0)
lam1_s=lam1_s0*np.ones_like(S)*mbs[0] #initialize sparsity parameters
### Get shape estimates on subset of data ###
if iters0[0] > 0:
for it in range(len(iters0)):
if estimateNoise:
sn_target,sn_std= GetSnPSDArray(data0)#target noise level
else:
sn_target=np.zeros(prod(dims0[1:]))
sn_std=sn_target
MSE_target = np.mean(sn_target**2)
ES=ExponentialSearch(lam1_s) #object to update sparsity parameters
lam1_s=ES.lam
for kk in range(iters0[it]):
# update sparsity parameters
if kk%updateLambdaIntervals==0:
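# The expression below expands the per-pixel sum over time of (data0 - activity.T*S)^2 as
# Energy0 - 2*sum_k (activity.data0)*S + ((activity.activity.T)*S)*S, avoiding forming the
# full T x P residual; sn is the per-pixel residual noise level compared against sn_target.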
sn=np.sqrt(old_div(Energy0-2*np.sum(np.dot(activity,data0)*S,axis=0)+np.sum(np.dot(np.dot(activity,activity.T),S)*S,axis=0),dims0[0])) # efficient way to calculate the per-pixel residual noise level
delta_sn=sn-sn_target # noise margin
signcheck=(data0sum-np.dot(np.sum(activity.T,axis=0),S))<0
if PositiveError: #obsolete
delta_sn[signcheck]=-float("inf") # residual should not have negative pixels, so we increase lambda for these pixels
if len(S)==0:
spars=0
else:
spars=np.mean(S>0,axis=1)
temp=repeat(delta_sn.reshape(1,-1),L+adaptBias,axis=0)
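# Adjust the sparsity penalties: increase lam1_s where the residual is already below the
# noise target (or a component's area exceeds TargetAreaRatio[1]); decrease it where the
# residual is still above the noise target (and the component's area is below
# TargetAreaRatio[0]).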
if TargetAreaRatio==[]:
cond_decrease=temp>sn_std
cond_increase=temp<-sn_std
else:
if adaptBias:
spars[-1]=old_div((TargetAreaRatio[1]+TargetAreaRatio[0]),2) # ignore sparsity target for background (bias) component
temp2=repeat(spars.reshape(-1,1),len(S[0]),axis=1)
cond_increase=np.logical_or(temp2>TargetAreaRatio[1],temp<-sn_std)
cond_decrease=np.logical_and(temp2<TargetAreaRatio[0],temp>sn_std)
ES.update(cond_decrease,cond_increase)
lam1_s=ES.lam
#Print residual error and additional information
MSE = np.mean(sn**2)
if verbose and L>0:
print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(MSE,MSE_target,np.mean(spars[:L]),np.mean(lam1_s)))
#add a new component
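# A new component is seeded at the pixel with the largest residual excess over the noise
# target, provided the overall MSE exceeds its target by more than 2 stds, that pixel's
# excess is above its own std, and the current components pass the checkNoZero test.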
if (kk%addComponentsIntervals==0) and (kk!=iters0[it]-1):
delta_sn[signcheck]=-float("inf") # residual should not have negative pixels
new_cent=np.argmax(delta_sn) #should I smooth the data a bit first?
MSE_std=np.mean(sn_std**2)
checkNoZero= not((0 in np.sum(activity,axis=1)) and (0 in np.sum(S,axis=1)))
if ((MSE-MSE_target>2*MSE_std) and checkNoZero and (delta_sn[new_cent]>sn_std[new_cent])):
S, activity, mask,centers,boxes,L=addComponent(new_cent,data0,dims0,old_div(R,ds),S, activity, mask,centers,boxes,adaptBias)
new_lam=lam1_s0*np.ones_like(data0[0,:]).reshape(1,-1)
lam1_s=np.insert(lam1_s,0,values=new_lam,axis=0)
ES=ExponentialSearch(lam1_s) #we need to restart exponential search each time we add a component
#apply additional constraints/processing
if SigmaBlur==[]:
if comp_method == 'random':
S = HALS4shape(dataL, S, activity.dot(Lmatrix),mask,lam1_s,lam2_s,adaptBias,inner_iterations)
else:
S = HALS4shape(data0, S, activity,mask,lam1_s,lam2_s,adaptBias,inner_iterations)
else: #obsolete
S=FISTA4shape(data0, S, activity,mask,lam1_s,adaptBias,SigmaBlur,dims0)
if Connected==True:
S=LargestConnectedComponent(S,dims0,adaptBias)
if WaterShed==True:
S=LargestWatershedRegion(S,dims0,adaptBias)
if comp_method == 'random':
NonNegative=False
activity = HALS4activity(dataR, S.dot(Rmatrix.T), activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,inner_iterations)
else:
if comp_method == 'svd':
NonNegative=False
activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,inner_iterations)
if SigmaMask!=[]:
mask=GrowMasks(S,mask,boxes,dims0,adaptBias,SigmaMask)
S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
lam1_s=ES.lam
if SmoothBkg==True:
S=SmoothBackground(S,dims0,adaptBias,tuple(old_div(np.array(sig),np.array(ds))))
print('Subsampled iteration',kk,'it=',it,'L=',L)
# use next (smaller) value for temporal downscaling
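# Rebuild data0 from the full data with the next (smaller) minibatch size mb: average
# every mb consecutive frames, then spatially average over ds blocks, and rescale lam1_s
# to the new effective number of frames.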
if it < len(iters0) - 1:
mb = mbs[it + 1]
data0 = data[:len(data) // mb * mb].reshape(-1, mb, prod(dims[1:])).mean(1)
if D==4:
data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds[0], int(old_div(dims[2], ds[1])), ds[1],
int(old_div(dims[3], ds[2])), ds[2]).mean(-1).mean(-2).mean(-3)
else:
data0 = data0.reshape(len(data0), int(old_div(dims[1], ds[0])), ds[0], int(old_div(dims[2], ds[1])),
ds[1]).mean(-1).mean(-2)
data0.shape = (len(data0), -1)
activity = ones((L + adaptBias, len(data0))) * activity.mean(1).reshape(-1, 1)
lam1_s=lam1_s*mbs[it+1]/mbs[it]
activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,30)
S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
lam1_s=ES.lam
### Stop adding components ###
if L==0: #if no non-background components found, return empty arrays
print('No non-background components found, aborting...')
return [], [], []
if FineTune: ### Upscale Back to full data ##
activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1)
data0=data
dims0=dims
if D==4:
S = repeat(repeat(repeat(S.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2), ds[2], 3)
lam1_s= repeat(repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2), ds[2], 3)
else:
S = repeat(repeat(S.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2)
lam1_s= repeat(repeat(lam1_s.reshape((-1,) + dims0[1:]), ds[0], 1), ds[1], 2)
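# repeat() upsamples by integer factors only, so if a spatial dimension of the full data
# is not an exact multiple of ds, the loop below pads S and lam1_s by replicating their
# last slice along that axis until they match the full-resolution dims.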
for dd in range(1,D):
while S.shape[dd]<dims[dd]:
shape_append=np.array(S.shape)
shape_append[dd]=1
S=np.append(S,values=np.take(S,-1,axis=dd).reshape(shape_append),axis=dd)
lam1_s=np.append(lam1_s,values=np.take(lam1_s,-1,axis=dd).reshape(shape_append),axis=dd)
S=S.reshape(L + adaptBias, -1)
lam1_s=lam1_s.reshape(L+ adaptBias,-1)
for ll in range(L):
boxes[ll] = GetBox(centers[ll], R, dims[1:])
temp = zeros(dims[1:])
temp[[slice(*a) for a in boxes[ll]]] = 1
mask[ll] = np.where(temp.ravel())[0]
if FixSupport: #obsolete
for ll in range(L):
lam1_s[ll,S[ll]==0]=float("inf")
ES=ExponentialSearch(lam1_s)
activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur, 30)
S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
lam1_s=ES.lam
if estimateNoise:
sn_target,sn_std= GetSnPSDArray(data0)#target noise level
else:
sn_target=np.zeros(prod(dims0[1:]))
sn_std=sn_target
MSE_target = np.mean(sn_target**2)
MSE_std=np.mean(sn_std**2)
# MSE = np.mean((data0-np.dot(activity.T,S))**2)
#### Main Loop ####
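# Each main iteration: HALS update of the shapes (with the current lam1_s), optional
# connectivity/watershed constraints, HALS update of the activity (optionally followed by
# OASIS deconvolution), renormalize/delete/sort of the components, and every
# updateLambdaIntervals iterations an exponential-search update of lam1_s.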
print('starting main NMF loop')
for kk in range(iters):
lam1_s=ES.lam #update sparsity parameters
if SigmaBlur==[]:
S = HALS4shape(data0, S, activity,mask,lam1_s,lam2_s,adaptBias,inner_iterations)
else: #obsolete
S = FISTA4shape(data0, S, activity,mask,lam1_s,adaptBias,SigmaBlur,dims0)
#apply additional constraints/processing
if Connected==True:
S=LargestConnectedComponent(S,dims0,adaptBias)
if WaterShed==True:
S=LargestWatershedRegion(S,dims0,adaptBias)
if kk==iters-1:
if FinalNonNegative==False:
NonNegative=False
activity = HALS4activity(data0, S, activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,inner_iterations)
if FineTune and Deconvolve:
for ll in range(L):
if np.sum(np.abs(activity[ll])>0)>30: #make sure there is enough signal before we try to deconvolve
activity[ll], _, _, _, _ = deconvolve(activity[ll], penalty=0)
if SigmaMask!=[]:
mask=GrowMasks(S,mask,boxes,dims0,adaptBias,SigmaMask)
S, activity, mask,centers,boxes,ES,L=RenormalizeDeleteSort(S, activity, mask,centers,boxes,ES,adaptBias,MedianFilt)
# Measure MSE and update sparsity parameters
print('main iteration kk=',kk,'L=',L)
if (kk+1)%updateLambdaIntervals==0:
sn=np.sqrt(old_div((Energy-2*np.sum(np.dot(activity,data0)*S,axis=0)+np.sum(np.dot(np.dot(activity,activity.T),S)*S,axis=0)),dims0[0]))
delta_sn=sn-sn_target
MSE = np.mean(sn**2)
signcheck=(datasum-np.dot(np.sum(activity.T,axis=0),S))<0
if PositiveError: #obsolete
delta_sn[signcheck]=-float("inf") # residual should not have negative pixels, so we increase lambda for these pixels
if len(S)==0:
spars=0
else:
spars=np.mean(S>0,axis=1)
temp=repeat(delta_sn.reshape(1,-1),L+adaptBias,axis=0)
if TargetAreaRatio==[]:
cond_decrease=temp>sn_std
cond_increase=temp<-sn_std
else:
if adaptBias:
spars[-1]=old_div((TargetAreaRatio[1]+TargetAreaRatio[0]),2) # ignore sparsity target for background (bias) component
temp2=repeat(spars.reshape(-1,1),len(S[0]),axis=1)
cond_increase=np.logical_or(temp2>TargetAreaRatio[1],temp<-sn_std)
cond_decrease=np.logical_and(temp2<TargetAreaRatio[0],temp>sn_std)
ES.update(cond_decrease,cond_increase)
lam1_s=ES.lam
if kk<old_div(iters,3): #restart exponential search unless enough iterations have passed
ES=ExponentialSearch(lam1_s)
else:
if not(np.any(cond_increase) or np.any(cond_decrease)):
print('sparsity target reached')
break
if L+adaptBias>1: # if we have more than one component, just keep exponentiated gradient descent instead
if (kk+1)%updateRhoIntervals==0: #update rho every updateRhoIntervals if we are still not converged
if np.any(spars[:L]<TargetAreaRatio[0]) or np.any(spars[:L]>TargetAreaRatio[1]):
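# rho -> 2 - 1/rho shrinks toward its fixed point at 1, so the exponential-search step
# for lam1_s becomes progressively gentler in later main iterations.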
ES.rho=2-old_div(1,(ES.rho))
print('rho=',ES.rho)
ES=ExponentialSearch(lam1_s,rho=ES.rho)
# print MSE and other information
if verbose:
print(' MSE = {0:.6f}, Target MSE={1:.6f},Sparsity={2:.4f},lam1_s={3:.6f}'.format(MSE,MSE_target,np.mean(spars[:L]),np.mean(lam1_s)))
if kk == (iters - 1):
print('Maximum iteration limit reached')
MSE_array.append(MSE)
# Some post-processing
S=S.reshape((-1,) + dims[1:])
# S,activity,L=PruneComponents(S,activity,L) #prune "bad" components
if len(S)>1:
S,activity,L=MergeComponents(S,activity,L,threshold_activity=MergeThreshold_activity,threshold_shape=MergeThreshold_shapes,sig=10) #merge very similar components
if not FineTune:
activity = ones((L + adaptBias, dims[0])) * activity.mean(1).reshape(-1, 1) #extract activity from full data
activity=HALS4activity(data, S.reshape((len(S),-1)), activity,NonNegative,lam1_t,lam2_t,dims0,SigmaBlur,iters=30)
return asarray(MSE_array), S, activity
# example to check code works
#T = 1000
#X = 201
#Y = 101
#data = np.random.randn(T, X, Y)
#centers = asarray([[40, 30]])
#data[:, 30:45, 25:33] += 2*np.sin(np.array(range(T))/200).reshape(-1,1,1)*np.ones([T,15,8])
#sig = [300, 300]
#
#MSE_array, shapes, activity = LocalNMF(
# data, centers, sig, NonNegative=True, verbose=True,lam1_s=0.1,adaptBias=True)
#
#
#import matplotlib.pyplot as plt
#plt.imshow(shapes[0])
#
#for ll in range(len(shapes)):
# print(np.mean(shapes[ll] > 0))