def _read_traces(h5_file, key, end_frame):
    """Return dataset ``key`` transposed to (neuron, frame) and cut at ``end_frame``.

    The dataset is stored as (frame, neuron) on disk, so it is transposed
    before the frame axis is sliced. Returns ``None`` when ``key`` is absent.
    """
    if key not in h5_file:
        return None
    return np.array(h5_file[key]).T[:, :end_frame]


def load_annotation_file(file_path, end_frame, dims):
    """Load annotated spatial footprints and traces from HDF5 into ``Estimates``.

    Parameters
    ----------
    file_path : str
        HDF5 file to read. Dataset ``'A'`` is required; ``'C'`` and
        ``'C_raw'`` are optional.
    end_frame : int or None
        Traces keep only their first ``end_frame`` frames (``None`` keeps all).
    dims : tuple of int
        Field-of-view size ``(dims[0], dims[1])`` used to reshape ``'A'``.

    Returns
    -------
    Estimates
        ``A`` is a ``csc_matrix`` of shape ``(dims[0] * dims[1], neuron_num)``
        with pixels in column-major order. ``C`` holds the ``'C'`` traces (or
        ``None``), and an extra ``C_raw`` attribute holds ``'C_raw'`` (or
        ``None``). Traces are shaped (neuron, frame).

    Raises
    ------
    KeyError
        If the file has no ``'A'`` dataset.
    ValueError
        If ``'A'`` cannot be reshaped to ``dims``, or if its number of
        components disagrees with the number of traces.
    """
    with h5py.File(file_path, 'r') as f:
        C = _read_traces(f, 'C', end_frame)
        C_raw = _read_traces(f, 'C_raw', end_frame)
        # Indexing (not f.get) gives a clear KeyError when 'A' is missing.
        A = np.array(f['A'], dtype='float32').T

    # Let numpy infer the component axis so files without any trace dataset
    # still load (previously neuron_num was then left unbound).
    A = A.reshape(dims[0], dims[1], -1)
    neuron_num = A.shape[-1]
    for name, traces in (('C', C), ('C_raw', C_raw)):
        if traces is not None and traces.shape[0] != neuron_num:
            raise ValueError(
                f"'{name}' has {traces.shape[0]} traces but 'A' has "
                f"{neuron_num} spatial components"
            )

    # Swap the two spatial axes before flattening so that pixel (i, j) lands
    # on row i + j * dims[0], i.e. column-major (Fortran) pixel order.
    A = csc_matrix(A.transpose(1, 0, 2).reshape(-1, neuron_num))
    estimates = Estimates(A=A, b=None, C=C, f=None, R=None, dims=dims)
    estimates.C_raw = C_raw
    return estimates
opts = params.CNMFParams(params_dict=params_dict) cnm = load_CNMF(fname_new[:-5] + '_cnmf_gsig.hdf5') # %% prepare ground truth masks gt_file = os.path.join( os.path.split(fname_new)[0], os.path.split(fname_new)[1][:-4] + 'match_masks.npz') with np.load(gt_file, encoding='latin1') as ld: print(ld.keys()) Cn_orig = ld['Cn'] gt_estimate = Estimates(A=scipy.sparse.csc_matrix(ld['A_gt'][()]), b=ld['b_gt'], C=ld['C_gt'], f=ld['f_gt'], R=ld['YrA_gt'], dims=(ld['d1'], ld['d2'])) min_size_neuro = 3 * 2 * np.pi max_size_neuro = (2 * params_dict['gSig'][0])**2 * np.pi gt_estimate.threshold_spatial_components(maxthr=0.2, dview=dview) gt_estimate.remove_small_large_neurons(min_size_neuro, max_size_neuro) _ = gt_estimate.remove_duplicates(predictions=None, r_values=None, dist_thr=0.1, min_dist=10, thresh_subset=0.6) print(gt_estimate.A_thr.shape) for gr_snr in SNRs_grid:
C_gt = ld['C_gt'] Cn_orig = ld['Cn'] if ds_factor > 1: A_gt2 = np.concatenate([ cv2.resize(A_gt[:, fr_].reshape(dims_or, order='F'), cnm.dims[::-1]).reshape(-1, order='F')[:, None] for fr_ in range(A_gt.shape[-1]) ], axis=1) Cn_orig = cv2.resize(Cn_orig, cnm.dims[::-1]) else: A_gt2 = A_gt.copy() gt_estimate = Estimates(A=scipy.sparse.csc_matrix(A_gt2), b=None, C=C_gt, f=None, R=None, dims=cnm.dims) gt_estimate.threshold_spatial_components(maxthr=global_params['max_thr'], dview=None) gt_estimate.remove_small_large_neurons(min_size_neuro, max_size_neuro) _ = gt_estimate.remove_duplicates(predictions=None, r_values=None, dist_thr=0.1, min_dist=10, thresh_subset=0.6) print(gt_estimate.A.shape) # %% compute performance and plot against consensus annotations tp_gt, tp_comp, fn_gt, fp_comp, performance_cons_off = compare_components( gt_estimate,
# %% Plot the first 2000 samples of the reconstructed signal of neuron `n`
# (`vpy` and `n` come from earlier in the script).
plt.figure()
plt.plot(vpy.estimates['recons_signal'][n][:2000])

# %% Wrap the VolPy output in a CaImAn Estimates object so the CNMF viewing
# tools (contours, movie playback) can be reused on it.
from caiman.source_extraction.cnmf.estimates import Estimates
import scipy.sparse  # explicit: a bare `import scipy` may not expose scipy.sparse on older SciPy

estimates = vpy.estimates.copy()
# Stack the 2-D spatial filters (N, d1, d2) -> (d1, d2, N), then flatten the
# pixels in Fortran order so each column is one footprint: (d1 * d2, N).
spatial_filters = np.array(estimates['spatial_filter']).transpose([1, 2, 0])
# Take the field of view from the filters instead of assuming 100x100, so the
# footprints and Estimates.dims always describe the same pixel grid.
fov_dims = spatial_filters.shape[:2]
A = spatial_filters.reshape((-1, len(estimates['spatial_filter'])), order='F')
A = A / A.max(axis=0)  # scale each footprint so its maximum is 1
# Two all-zero background components (spatial b, temporal f) as placeholders.
b = np.zeros((A.shape[0], 2))
A = scipy.sparse.csc_matrix(A)
b = scipy.sparse.csc_matrix(b)
C = np.array(estimates['t_rec'])
f = np.zeros((2, C.shape[1]))
# Residual: trace 't' minus reconstructed trace 't_rec'.
R = np.array(estimates['t']) - C
est = Estimates(A=A, C=C, b=b, f=f, R=R, dims=fov_dims)
est.YrA = R
#est.plot_contours(img=summary_image[:,:,2])

# %% Load the memory-mapped movie and play it with the components overlaid.
Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order='F')
est.dview = dview
#est.view_components(img=summary_image[:,:,2])
est.play_movie(imgs=images, magnification=4)

# %% Load previously saved estimates. allow_pickle=True unpickles the file's
# contents, so only point this at trusted files.
est = np.load('/home/nel/data/voltage_data/volpy_paper/reconstructed/estimates.npz', allow_pickle=True)['arr_0'].item()
fnames = ['/home/nel/data/voltage_data/volpy_paper/memory/403106_3min_10000._rig__d1_512_d2_128_d3_1_order_F_frames_10000_.mmap']
def OnACID_A_init(fr, fnames, out, hfile, epochs=2):
    """Fit OnACID seeded with spatial footprints loaded from an HDF5 file.

    Parameters
    ----------
    fr : float
        Frame rate; printed and forwarded as the ``'fr'`` parameter.
    fnames : list of str
        Movie file(s) to process, forwarded as ``'fnames'``.
    out : str
        Path handed to ``save`` once online fitting finishes.
    hfile : str
        HDF5 file passed (opened read-only) to ``load_A`` to obtain the
        initial spatial components.
    epochs : int, optional
        Value forwarded as the ``'epochs'`` parameter (default 2).
    """
    # Spatial downsampling factor; 1 keeps full resolution.
    ds_factor = 1
    # Expected neuron half-size (4, 4), rescaled for downsampling.
    gSig = tuple(np.ceil(np.array((4, 4)) / ds_factor).astype('int'))
    # Largest shift allowed by online motion correction, rescaled likewise.
    max_shifts_online = np.ceil(10. / ds_factor).astype('int')
    n_background = 2  # forwarded both as 'gnb' and as 'nb'
    update_num_comps = False  # also reused for 'dist_shape_update'
    patch_size = 32  # CNMF receives half of it as 'rf'

    print("Frame rate: {}".format(fr))

    params_dict = dict(
        fr=fr,
        fnames=fnames,
        decay_time=.4,  # approximate transient length in seconds
        gSig=gSig,
        gnb=n_background,
        p=1,  # order of the AR indicator dynamics
        min_SNR=2.5,
        rval_thr=0.8,
        ds_factor=ds_factor,
        nb=n_background,
        motion_correct=True,
        normalize=True,
        sniper_mode=False,
        K=2,  # initial number of components
        use_cnn=False,
        epochs=epochs,
        max_shifts_online=max_shifts_online,
        pw_rigid=False,
        min_num_trial=10,
        show_movie=False,
        save_online_movie=False,
        max_num_added=5,
        max_comp_update_shape=np.inf,
        update_num_comps=update_num_comps,
        dist_shape_update=update_num_comps,
        init_batch=500,  # frames used for initialization
        init_method='cnmf',
        rf=patch_size // 2,
        stride=3,  # overlap between patches
        thresh_CNN_noisy=0.8,
    )
    cnmf_opts = CNMFParams(params_dict=params_dict)

    # Read the seed footprints; the HDF5 file is closed once they are loaded.
    with h5py.File(hfile, 'r') as h5:
        seed_estimates = Estimates(A=load_A(h5))

    online_model = online_cnmf.OnACID(params=cnmf_opts, estimates=seed_estimates)
    online_model.estimates = seed_estimates
    online_model.fit_online()
    online_model.save(out)