def optimize_combine_1(uniques: sparse.csr_matrix, times_uniques: list, dims: tuple, Params: dict, filename_GT: str):
    '''Optimize 1 post-processing parameter: "cons".
    Start after the first COM merging. The remaining merging steps (second COM merging,
    IoU merging and consume-ratio merging) are applied once, and then the recall, precision,
    and F1 are calculated against the GT masks for every value in "list_cons".

    Inputs:
        uniques (sparse.csr_matrix of float32, shape = (n,Lx*Ly)): the neuron masks to be merged.
        times_uniques (list of 1D numpy.array): indices of frames when the neuron is active.
        dims (tuple of int): shape of the data. Either the lateral shape (Lx, Ly) or a
            3-element shape (nframes, Lx, Ly); only the last two entries are used.
        Params (dict): Ranges of post-processing parameters to optimize over.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: (float) Threshold to binarize the real-number mask.
            Params['thresh_COM']: (float or int) Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params['thresh_IOU']: (float) Threshold of IoU used for merging neurons.
                The consume-ratio threshold is derived from it as (1 + thresh_IOU) / 2.
            Params['list_cons']: (list) Range of minimum number of consecutive frames that a neuron should be active for.
        filename_GT (str): file name of the GT masks.
            The GT masks are stored in a ".mat" file, and dataset "GTMasks_2" is the GT masks
            (shape = (Ly0*Lx0,n) when saved in MATLAB).

    Outputs:
        Recall_k (1D numpy.array of float): Recall for all cons.
        Precision_k (1D numpy.array of float): Precision for all cons.
        F1_k (1D numpy.array of float): F1 for all cons.
    '''
    avgArea = Params['avgArea']
    thresh_mask = Params['thresh_mask']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    # The consume-ratio threshold is not searched independently; it is tied to thresh_IOU.
    thresh_consume = (1 + thresh_IOU) / 2
    list_cons = Params['list_cons']
    # Only the lateral shape is needed. Taking the last two entries accepts both
    # (Lx, Ly) and (nframes, Lx, Ly), the latter being what dims[1], dims[2] assumed.
    dims_lateral = tuple(dims[-2:])

    # second merge neurons with close COM.
    groupedneurons, times_groupedneurons = group_neurons(uniques, \
        thresh_COM, thresh_mask, dims_lateral, times_uniques, useMP=False)
    # Merge neurons with high IoU.
    piecedneurons_1, times_piecedneurons_1 = piece_neurons_IOU(groupedneurons, \
        thresh_mask, thresh_IOU, times_groupedneurons)
    # Merge neurons with high consume ratio.
    piecedneurons, times_piecedneurons = piece_neurons_consume(piecedneurons_1, \
        avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
    masks_final_2 = piecedneurons
    times_final = [np.unique(x) for x in times_piecedneurons]

    data_GT = loadmat(filename_GT)
    # MATLAB stores the masks as (pixels, n); transpose to one mask per row.
    GTMasks_2 = data_GT['GTMasks_2'].transpose()
    # Search for optimal "cons" used to refine segmented neurons.
    Recall_k, Precision_k, F1_k = refine_seperate_multi(GTMasks_2, \
        masks_final_2, times_final, list_cons, thresh_mask, display=False)
    return Recall_k, Precision_k, F1_k
def final_merge(tuple_temp, Params):
    '''Run one last round of merging at the end of online processing.

    All previously detected neurons are merged according to their IoU and consume ratio
    (as in batch mode), and then refined by the consecutive-frame requirement.

    Inputs:
        tuple_temp (tuple, shape = (5,)): Segmented masks with statistics.
            Only the second (real-number masks) and third (active frame indices)
            entries are used here.
        Params (dict): Parameters for post-processing.
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.
            Params['avgArea']: The typical neuron area (unit: pixels).

    Outputs:
        Masks_2b (sparse.csr_matrix of bool): the final segmented binary neuron masks after consecutive refinement.
        times_cons (list of 1D numpy.array): indices of frames when the final neuron is active.
        NOTE: if "tuple_temp" contains no masks, "tuple_temp" itself (a 5-tuple) is
        returned unchanged instead of the 2-tuple above, and Params is not read.
    '''
    _, real_masks, active_frames, _, _ = tuple_temp
    if len(real_masks) == 0:
        # Nothing to merge: hand the input back untouched.
        return tuple_temp

    thresh_mask = Params['thresh_mask']
    thresh_IOU = Params['thresh_IOU']
    thresh_consume = Params['thresh_consume']
    cons = Params['cons']
    avgArea = Params['avgArea']

    # Stack the per-neuron rows into a single 2D sparse matrix before merging.
    stacked_masks = sparse.vstack(real_masks)
    iou_masks, iou_times = piece_neurons_IOU(stacked_masks, thresh_mask, thresh_IOU, active_frames)
    consumed_masks, consumed_times = piece_neurons_consume(
        iou_masks, avgArea, thresh_mask, thresh_consume, iou_times)
    refined_masks, refined_times = refine_seperate(consumed_masks, consumed_times, cons, thresh_mask)
    return refined_masks, refined_times
def suns_online(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, show_intermediate=True, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure.
    It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters
    in "Params_post" to process the raw video stored in "filename_video".
    The first "frames_init" frames are processed together for initialization; the remaining
    frames are processed one at a time, and the newly segmented neurons are merged with the
    existing neurons every "merge_every" frames.

    Inputs:
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model.
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
                NOTE(review): this function only uses it as a lower bound on the size of the
                per-frame timing buffer; the whole video in "filename_video" is processed.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels.
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merges the newly segmented frames every "merge_every" frames.
            The merge is pipelined over the 5 frames that end each batch (stages 1-5 below).
            NOTE(review): for merge_every < 5 the stages of consecutive batches appear to never
            line up again, so intermediate merges stop and everything is merged only at the
            last frame — confirm before using small values.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
            When False, no past-frame buffer is kept and None is passed to "preprocess_online" instead.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used.
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used.
        show_intermediate (bool, default to True): Indicator of whether
            consecutive frame requirement is applied to screen neurons after every update.
        prealloc (bool, default to True): True if pre-allocate memory space for large variables.
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up.
        p (multiprocessing.Pool, default to None): pool passed to "init_online".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks.
        Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix.
        time_total (1D numpy.ndarray of float, shape = (3,)): the total time (s) spent for
            initialization, online processing, and total processing. All zeros if display is False.
        time_frame (1D numpy.ndarray of float, shape = (3,)): the average time (ms) spent on every frame
            for initialization, online processing, and total processing. All zeros if display is False.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx / 8) * 8
    colspad = math.ceil(Ly / 8) * 8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2 * leng_tf  # number of past frames stored for temporal filtering

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    thresh_mask = Params_post['thresh_mask']
    thresh_COM0 = Params_post['thresh_COM0']
    thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    cons = Params_post['cons']
    thresh_pmap_float = (Params_post['thresh_pmap'] + 1) / 256  # for published version

    # Spatial filtering preparation
    if useSF == True:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)

        # if the learned 2D and 3D wisdom files have been saved, load them.
        # Otherwise, learn wisdom later
        # NOTE(review): the path separator is hard-coded as a Windows backslash.
        Length_data2 = str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\' + Length_data2)
        Length_data3 = str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\' + Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b, fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    frames_initf = frames_init - leng_tf + 1
    if useTF == True:
        if prealloc:
            # past frames stored for temporal filtering
            past_frames = np.ones((leng_past, rowspad, colspad), dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad), dtype='float32')
    else:
        past_frames = None

    if prealloc:  # Pre-allocate memory for some future variables
        med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        video_input = np.ones((frames_initf, rowspad, colspad), dtype='float32')
        pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.ones(dimspad, dtype='float32')
        pmaps_b = np.ones(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.ones((frames_init, rowspad, colspad), dtype='float32')
    else:
        med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        video_input = np.zeros((frames_initf, rowspad, colspad), dtype='float32')
        pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.zeros(dimspad, dtype='float32')
        pmaps_b = np.zeros(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    if display:
        time_params_done = time.time()
        print('Parameter initialization time: {} s'.format(time_params_done - start))

    # %% Load raw video
    with h5py.File(filename_video, 'r') as h5_img:
        video_raw = np.array(h5_img['mov'])
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    # Per-frame processing time (not returned). It is indexed by the frame index t,
    # so it must hold at least nframesf entries regardless of Params_pre['nn'].
    list_time_per = np.zeros(max(nn, nframesf))
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load - time_params_done))

    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if useTF == True:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)
    if show_intermediate:
        Masks_2 = select_cons(tuple_temp)
    if display:
        end_init = time.time()
        time_init = end_init - start_init
        time_frame_init = time_init / (frames_initf) * 1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(time_init, time_frame_init))

    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing.
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb = np.zeros(dimspad, dtype='float32')

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf + 1
    t_merge = frames_initf
    for t in range(frames_initf, nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t + leng_tf - 1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing. Without temporal filtering there is no past-frame buffer,
        # so pass None instead of slicing it (slicing None raised TypeError).
        # NOTE(review): assumes preprocess_online ignores this argument when useTF is False.
        if past_frames is not None:
            past_window = past_frames[current_frame - leng_tf:current_frame]
        else:
            past_window = None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            past_window, mask2, bf, fft_object_b, \
            fft_object_c, Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t - frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if t_past == frames_init - 1:
                # update median and median-based standard deviation every "frames_init" frames
                if useSNR:
                    med_frame3 = SNR_normalization(
                        video_tf_past, med_frame2, (rowspad, colspad), 1, display=False)
                else:
                    med_frame3 = median_normalization(
                        video_tf_past, med_frame2, (rowspad, colspad), 1, display=False)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float, minArea, avgArea, useWT)
        segs_all.append(segs)

        # The merge of each batch of "merge_every" frames is spread over 5 consecutive frames
        # (stage k runs k-1 frames after the batch ends) to spread the cost out.
        # On the last frame, all stages run together to flush the remaining frames.
        is_last = (t == nframesf - 1)

        # temporal merging 1: combine neurons with COM distance smaller than thresh_COM0
        if ((t + 1 - t_merge) == merge_every) or is_last:
            totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs_all[t_merge:])
            uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
                areas, probmapID, minArea=0, thresh_COM0=thresh_COM0, useMP=useMP)

        # temporal merging 2: combine neurons with COM distance smaller than thresh_COM
        if ((t - t_merge) == merge_every) or is_last:
            if uniques.size:
                groupedneurons, times_groupedneurons = \
                    group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques, useMP=useMP)

        # temporal merging 3: combine neurons with IoU larger than thresh_IOU
        if ((t - 1 - t_merge) == merge_every) or is_last:
            if uniques.size:
                piecedneurons_1, times_piecedneurons_1 = \
                    piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)

        # temporal merging 4: combine neurons with consume ratio larger than thresh_consume
        if ((t - 2 - t_merge) == merge_every) or is_last:
            if uniques.size:
                piecedneurons, times_piecedneurons = \
                    piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
                # masks of new neurons
                masks_add = piecedneurons
                # indices of frames when the neurons are active (shifted to absolute frame indices)
                times_add = [np.unique(x) + t_merge for x in times_piecedneurons]

                # Refine neurons using consecutive occurrence
                if masks_add.size:
                    # new real-number masks
                    masks_add = [x for x in masks_add]
                    # new binary masks
                    Masksb_add = [(x >= x.max() * thresh_mask).astype('float') for x in masks_add]
                    # areas of new masks
                    area_add = np.array([x.nnz for x in Masksb_add])
                    # indicators of whether the new masks satisfy consecutive frame requirement
                    have_cons_add = refine_seperate_cons_online(times_add, cons)
                else:
                    Masksb_add = []
                    area_add = np.array([])
                    have_cons_add = np.array([])
            else:  # does not find any active neuron
                Masksb_add = []
                masks_add = []
                times_add = times_uniques
                area_add = np.array([])
                have_cons_add = np.array([])
            tuple_add = (Masksb_add, masks_add, times_add, area_add, have_cons_add)

        # temporal merging 5: merge newly found neurons within the recent "merge_every" frames with existing neurons
        if ((t - 3 - t_merge) == merge_every) or is_last:
            tuple_temp = merge_2(tuple_temp, tuple_add, dims, Params_post)
            t_merge += merge_every
            if show_intermediate:
                Masks_2 = select_cons(tuple_temp)

        current_frame += 1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf + 1
            if past_frames is not None:
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames has been processed'.format(t))

    if not show_intermediate:
        Masks_2 = select_cons(tuple_temp)
    # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
    if len(Masks_2):
        Masks_2 = sparse.vstack(Masks_2)
    else:
        Masks_2 = sparse.csc_matrix((0, dims[0] * dims[1]))

    if display:
        end_online = time.time()
        time_online = end_online - start_online
        # Guard against a video with no frames beyond the initialization frames.
        time_frame_online = time_online / max(nframesf - frames_initf, 1) * 1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final - start_init
        time_frame_all = time_all / nframes * 1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array([time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3, ))
        time_frame = np.zeros((3, ))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame
def merge_complete(segs, dims, Params):
    '''Temporally merge the masks segmented from a few frames, and compute their statistics.

    The masks are merged by close COM twice, first with "thresh_COM0" and then with
    "thresh_COM". They are then merged by high IoU and by high consume ratio. For every
    resulting neuron, the active frame indices, the area, and whether it satisfies the
    consecutive-frame requirement are returned.

    Inputs:
        segs (list): A list of segmented masks for every frame with statistics.
        dims (tuple of int, shape = (2,)): lateral dimension of the image.
        Params (dict): Parameters for post-processing.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.

    Outputs:
        Masks_2 (list of sparse rows, shape = (1,Lx*Ly)): binary (0/1, stored as float) mask of each neuron.
        masks_final_2 (list of sparse rows, shape = (1,Lx*Ly)): real-number mask of each neuron.
        times_final (list of 1D numpy.ndarray of int): indices of frames when each neuron is active.
        area (1D numpy.ndarray): number of nonzero pixels of each binary mask.
        have_cons (1D numpy.ndarray): result of "refine_seperate_cons_online", indicating
            whether each neuron satisfies the consecutive frame requirement.
        The above outputs are often grouped into a tuple (shape = (5,)):
            Segmented masks with statistics after update.
        When no neuron is found, Masks_2 is [], and area and have_cons are empty arrays.
    '''
    avgArea = Params['avgArea']
    thresh_mask = Params['thresh_mask']
    thresh_COM0 = Params['thresh_COM0']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    thresh_consume = Params['thresh_consume']
    cons = Params['cons']

    totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs)
    # Initial merge of neurons with close COM.
    uniques, times_uniques = unique_neurons2_simp(
        totalmasks, neuronstate, COMs, areas, probmapID, minArea=0, thresh_COM0=thresh_COM0)
    if not uniques.size:
        # No active neuron at all.
        return [], [], times_uniques, np.array([]), np.array([])

    # Further merge neurons with close COM, then by IoU, then by consume ratio.
    grouped, grouped_times = group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques)
    by_iou, iou_times = piece_neurons_IOU(grouped, thresh_mask, thresh_IOU, grouped_times)
    merged, merged_times = piece_neurons_consume(by_iou, avgArea, thresh_mask, thresh_consume, iou_times)
    frame_lists = [np.unique(frames) for frames in merged_times]

    if not merged.size:
        # Merging left nothing: return the (empty) merged result as the real-number masks.
        return [], merged, frame_lists, np.array([]), np.array([])

    # Refine neurons using consecutive occurrence.
    real_rows = [row for row in merged]
    binary_rows = [(row >= row.max() * thresh_mask).astype('float') for row in real_rows]
    pixel_counts = np.array([row.nnz for row in binary_rows])
    cons_flags = refine_seperate_cons_online(frame_lists, cons)
    return binary_rows, real_rows, frame_lists, pixel_counts, cons_flags
def merge_complete_nocons(uniques, times_uniques, dims, Params):
    '''Temporally merge masks after the first COM merging, ignoring the consecutive-frame requirement.

    Used for parameter optimization: "cons" is searched in a later step, so every merged
    neuron is marked as satisfying the consecutive-frame requirement.

    Inputs:
        uniques (sparse.csr_matrix): the neuron masks after the first COM merging.
        times_uniques (list of 1D numpy.array): indices of frames when the neuron is active.
        dims (tuple of int, shape = (2,)): lateral dimension of the image.
        Params (dict): Parameters for post-processing.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.

    Outputs:
        Masks_2 (list of sparse rows, shape = (1,Lx*Ly)): binary (0/1, stored as float) mask of each neuron.
        masks_final_2 (list of sparse rows, shape = (1,Lx*Ly)): real-number mask of each neuron.
        times_final (list of 1D numpy.ndarray of int): indices of frames when each neuron is active.
        area (1D numpy.ndarray): number of nonzero pixels of each binary mask.
        have_cons (1D numpy.ndarray of bool): all True, one entry per neuron.
        The above outputs are often grouped into a tuple (shape = (5,)):
            Segmented masks with statistics after update.
        When "uniques" is empty, times_uniques is returned unchanged as times_final.
    '''
    avgArea = Params['avgArea']
    thresh_mask = Params['thresh_mask']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    thresh_consume = Params['thresh_consume']

    if not uniques.size:
        # No active neuron at all.
        return [], [], times_uniques, np.array([]), np.array([])

    # Further merge neurons with close COM, then by IoU, then by consume ratio.
    grouped, grouped_times = group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques)
    by_iou, iou_times = piece_neurons_IOU(grouped, thresh_mask, thresh_IOU, grouped_times)
    merged, merged_times = piece_neurons_consume(by_iou, avgArea, thresh_mask, thresh_consume, iou_times)
    frame_lists = [np.unique(frames) for frames in merged_times]

    if not merged.size:
        # Merging left nothing: return the (empty) merged result as the real-number masks.
        return [], merged, frame_lists, np.array([]), np.array([])

    real_rows = [row for row in merged]
    binary_rows = [(row >= row.max() * thresh_mask).astype('float') for row in real_rows]
    pixel_counts = np.array([row.nnz for row in binary_rows])
    # "cons" is searched in the next step, so every mask is treated as a valid neuron here.
    cons_flags = np.ones(len(real_rows), dtype='bool')
    return binary_rows, real_rows, frame_lists, pixel_counts, cons_flags
def complete_segment(pmaps: np.ndarray, Params: dict, useMP=True, useWT=False, display=False, p=None):
    '''Complete post-processing procedure.
    This can be run after or before probability thresholding, depending on whether Params['thresh_pmap'] is None.
    It first thresholds the "pmaps" (if Params['thresh_pmap'] is not None) into a binary array,
    then separates the active pixels into connected regions, discards regions smaller than Params['minArea'],
    uses optional watershed (if useWT=True) to further segment regions larger than Params['avgArea'],
    merges the regions from different frames with close COM, large IoU, or large consume ratio,
    and finally selects masks that are active for at least Params['cons'] frames.
    The outputs are "Masks_2", a 2D sparse matrix of the final segmented neurons,
    and "times_cons", a list of indices of frames when the final neuron is active.

    Inputs:
        pmaps (3D numpy.ndarray of uint8, shape = (nframes,Lx,Ly)): the probability map obtained after CNN inference.
            If Params['thresh_pmap']==None, pmaps must be previously thresholded.
        Params (dict): Parameters for post-processing.
            Params['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels.
                If Params['thresh_pmap']==None, then thresholding is not performed.
                This is used when thresholding is done before this function.
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up.
            Multiprocessing is only used when a pool "p" is also given; otherwise the frames
            are segmented serially.
        useWT (bool, default to False): Indicator of whether watershed is used.
        display (bool, default to False): Indicator of whether to show intermediate information
        p (multiprocessing.Pool, default to None): pool used to segment frames in parallel.

    Outputs:
        Masks_2 (sparse.csr_matrix of bool): the final segmented binary neuron masks after consecutive refinement.
        times_cons (list of 1D numpy.array): indices of frames when the final neuron is active.
    '''
    dims = pmaps.shape
    (nframes, Lx, Ly) = dims
    minArea = Params['minArea']
    avgArea = Params['avgArea']
    thresh_pmap = Params['thresh_pmap']
    thresh_mask = Params['thresh_mask']
    thresh_COM0 = Params['thresh_COM0']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    thresh_consume = Params['thresh_consume']
    cons = Params['cons']
    start_all = time.time()

    # Segment neuron masks from each frame of probability map
    start = time.time()
    if useMP and p is not None:
        segs = p.starmap(separate_neuron, [(frame, thresh_pmap, minArea, avgArea, useWT)
            for frame in pmaps], chunksize=1)
    else:
        # Serial path; also used when useMP is requested but no pool was supplied
        # (previously this crashed on p.starmap with the default p=None).
        segs = [separate_neuron(frame, thresh_pmap, minArea, avgArea, useWT) for frame in pmaps]
    end = time.time()
    num_neurons = sum(x[1].size for x in segs)
    if display:
        print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
            .format('separate Neurons', end - start, (end - start) / nframes * 1000),
            '{:6d} segmented neurons.'.format(num_neurons))

    if num_neurons == 0:
        print('No masks found. Please lower minArea or thresh_pmap.')
        Masks_2 = sparse.csc_matrix((0, Lx * Ly), dtype='bool')
        times_cons = []
    else:  # find active neurons
        start = time.time()
        # Initially merge neurons with close COM.
        totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs)
        uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
            areas, probmapID, minArea=0, thresh_COM0=thresh_COM0)
        end_unique = time.time()
        if display:
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('unique_neurons1', end_unique - start, (end_unique - start) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_uniques)))

        # Further merge neurons with close COM.
        groupedneurons, times_groupedneurons = \
            group_neurons(uniques, thresh_COM, thresh_mask, (dims[1], dims[2]), times_uniques)
        end_COM = time.time()
        if display:
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('group_neurons', end_COM - end_unique, (end_COM - end_unique) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_groupedneurons)))

        # Merge neurons with high IoU.
        piecedneurons_1, times_piecedneurons_1 = \
            piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)
        end_IOU = time.time()
        if display:
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('piece_neurons_IOU', end_IOU - end_COM, (end_IOU - end_COM) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_piecedneurons_1)))

        # Merge neurons with high consume ratio.
        piecedneurons, times_piecedneurons = \
            piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
        end_consume = time.time()
        if display:
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('piece_neurons_consume', end_consume - end_IOU, (end_consume - end_IOU) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_piecedneurons)))

        masks_final_2 = piecedneurons
        times_final = [np.unique(x) for x in times_piecedneurons]

        # Refine neurons using consecutive occurrence requirement
        start = time.time()
        Masks_2, times_cons = refine_seperate(masks_final_2, times_final, cons, thresh_mask)
        end_all = time.time()
        if display:
            # Report the number of neurons surviving the refinement (times_cons),
            # not the count before it (times_final).
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('refine_seperate', end_all - start, (end_all - start) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_cons)))
            print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '\
                .format('Total time', end_all - start_all, (end_all - start_all) / nframes * 1000),
                '{:6d} segmented neurons.'.format(len(times_cons)))

    return Masks_2, times_cons