def parameter_optimization_pipeline(file_CNN, network_input, dims,
        Params_set, filename_GT, batch_size_eval=1, useWT=False, useMP=True, p=None):
    '''Parameter optimization pipeline for one video and one trained CNN model.

    The trained CNN in "file_CNN" first predicts a probability map for every frame of
    "network_input". The recall, precision, and F1 of every post-processing parameter
    combination in "Params_set" are then computed against the GT masks in "filename_GT".

    Inputs:
        file_CNN (str): The path of the trained CNN model. Must be a ".h5" file.
        network_input (3D numpy.ndarray of float32, shape = (T,Lx,Ly)): the SNR video obtained after pre-processing.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
            The CNN output is cropped to this size.
        Params_set (dict): Ranges of post-processing parameters to optimize over.
            Params_set['list_minArea']: (list) Range of minimum area of a valid neuron mask (unit: pixels).
            Params_set['list_avgArea']: (list) Range of typical neuron area (unit: pixels).
            Params_set['list_thresh_pmap']: (list) Range of probability threshold.
            Params_set['thresh_mask']: (float) Threshold to binarize the real-number mask.
            Params_set['thresh_COM0']: (float) Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params_set['list_thresh_COM']: (list) Range of threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params_set['list_thresh_IOU']: (list) Range of threshold of IOU used for merging neurons.
            Params_set['thresh_consume']: (float) Threshold of consume ratio used for merging neurons.
            Params_set['list_cons']: (list) Range of minimum number of consecutive frames that a neuron should be active for.
        filename_GT (str): file name of the GT masks.
            The file must be a ".mat" file, with dataset "GTMasks" being the 2D sparse matrix
            (shape = (Ly0,Lx0,n) when saved in MATLAB).
        batch_size_eval (int, default to 1): batch size of CNN inference.
        useWT (bool, default to False): Indicator of whether watershed is used.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up.
        p (multiprocessing.Pool, default to None): pool of worker processes,
            forwarded unchanged to "parameter_optimization".

    Outputs:
        list_Recall (6D numpy.array of float): Recall for all parameter combinations.
        list_Precision (6D numpy.array of float): Precision for all parameter combinations.
        list_F1 (6D numpy.array of float): F1 for all parameter combinations.
            For these outputs, the orders of the tunable parameters are:
            "minArea", "avgArea", "thresh_pmap", "thresh_COM", "thresh_IOU", "cons"
    '''
    Lx, Ly = dims

    # Rebuild the network architecture, then restore the trained weights.
    network = get_shallow_unet()
    network.load_weights(file_CNN)

    # Time the CNN inference over the whole SNR video.
    tic = time.time()
    prob_full = network.predict(network_input, batch_size=batch_size_eval)
    toc = time.time()
    ms_per_frame = (toc-tic)/network_input.shape[0]*1000
    print('Average infrence time {} ms/frame'.format(ms_per_frame))

    # Drop the trailing channel axis and crop to the raw lateral size (the CNN input
    # may be larger, presumably zero-padded). Then convert the float probabilities
    # to uint8 so the parameter sweep below runs on a compact array.
    prob_crop = prob_full.squeeze(axis=-1)[:, :Lx, :Ly]
    pmaps = np.zeros(prob_crop.shape, dtype='uint8')
    fastuint(prob_crop, pmaps)
    del prob_full, prob_crop, network

    # Evaluate recall, precision, and F1 for every post-processing parameter combination.
    recall, precision, F1 = parameter_optimization(pmaps, Params_set, filename_GT,
        useMP=useMP, useWT=useWT, p=p)
    return recall, precision, F1
def train_CNN(dir_img, dir_mask, file_CNN, list_Exp_ID_train, list_Exp_ID_val, \ BATCH_SIZE, NO_OF_EPOCHS, num_train_per, num_total, dims, Params_loss=None, exist_model=None): '''Train a CNN model using SNR images in "dir_img" and the corresponding temporal masks in "dir_mask" identified for each video in "list_Exp_ID_train" using tensorflow generater formalism. The output are the trained CNN model saved in "file_CNN" and "results" containing loss. Inputs: dir_img (str): The folder containing the network_input (SNR images). Each file must be a ".h5" file, with dataset "network_input" being the SNR video (shape = (T,Lx,Ly)). dir_mask (str): The folder containing the temporal masks. Each file must be a ".h5" file, with dataset "temporal_masks" being the temporal masks (shape = (T,Lx,Ly)). file_CNN (str): The path to save the trained CNN model. list_Exp_ID_train (list of str): The list of file names of the training video(s). list_Exp_ID_val (list of str, default to None): The list of file names of the validation video(s). if list_Exp_ID_val is None, then no validation set is used BATCH_SIZE (int): batch size for CNN training. NO_OF_EPOCHS (int): number of epochs for CNN training. num_train_per (int): number of training images per video. num_total (int): total number of frames of a video (can be smaller than acutal number). dims (tuplel of int, shape = (2,)): lateral dimension of the video. Params_loss(dict, default to None): parameters of the loss function "total_loss" Params_loss['DL'](float): Coefficient of dice loss in the total loss Params_loss['BCE'](float): Coefficient of binary cross entropy in the total loss Params_loss['FL'](float): Coefficient of focal loss in the total loss Params_loss['gamma'] (float): first parameter of focal loss Params_loss['alpha'] (float): second parameter of focal loss exist_model (str, default to None): the path of existing model for transfer learning Outputs: results: the training results containing the loss information. 
In addition, the trained CNN model is saved in "file_CNN" as ".h5" files. ''' (rows, cols) = dims nvideo_train = len(list_Exp_ID_train) # Number of training videos # set how to choose training images train_every = max(1, num_total // num_train_per) start_frame_train = random.randint(0, train_every - 1) NO_OF_TRAINING_IMAGES = num_train_per * nvideo_train if list_Exp_ID_val is not None: # set how to choose validation images nvideo_val = len(list_Exp_ID_val) # Number of validation videos # the total number of validation images is about 1/9 of the traning images num_val_per = int((num_train_per * nvideo_train / nvideo_val) // 9) num_val_per = min(num_val_per, num_total) val_every = num_total // num_val_per start_frame_val = random.randint(0, val_every - 1) NO_OF_VAL_IMAGES = num_val_per * nvideo_val # %% Load traiming images and masks from h5 files # training images train_imgs = np.zeros((num_train_per * nvideo_train, rows, cols), dtype='float32') # temporal masks for training images train_masks = np.zeros((num_train_per * nvideo_train, rows, cols), dtype='uint8') if list_Exp_ID_val is not None: # validation images val_imgs = np.zeros((num_val_per * nvideo_val, rows, cols), dtype='float32') # temporal masks for validation images val_masks = np.zeros((num_val_per * nvideo_val, rows, cols), dtype='uint8') print('Loading training images and masks.') # Select training images: for each video, start from frame "start_frame", # select a frame every "train_every" frames, totally "train_val_per" frames for cnt, Exp_ID in enumerate(list_Exp_ID_train): h5_img = h5py.File(os.path.join(dir_img, Exp_ID + '.h5'), 'r') h5_mask = h5py.File(os.path.join(dir_mask, Exp_ID + '.h5'), 'r') num_frame = h5_img['network_input'].shape[0] if num_frame >= num_train_per: train_imgs[cnt*num_train_per:(cnt+1)*num_train_per,:,:] \ = np.array(h5_img['network_input'][start_frame_train:train_every*num_train_per:train_every]) train_masks[cnt*num_train_per:(cnt+1)*num_train_per,:,:] \ = 
np.array(h5_mask['temporal_masks'][start_frame_train:train_every*num_train_per:train_every]) else: train_imgs = np.array(h5_img['network_input']) train_masks = np.array(h5_mask['temporal_masks']) h5_img.close() h5_mask.close() if list_Exp_ID_val is not None: # Select validation images: for each video, start from frame "start_frame", # select a frame every "val_every" frames, totally "num_val_per" frames for cnt, Exp_ID in enumerate(list_Exp_ID_val): h5_img = h5py.File(os.path.join(dir_img, Exp_ID + '.h5'), 'r') val_imgs[cnt*num_val_per:(cnt+1)*num_val_per,:,:] \ = np.array(h5_img['network_input'][start_frame_val:val_every*num_val_per:val_every]) h5_img.close() h5_mask = h5py.File(os.path.join(dir_mask, Exp_ID + '.h5'), 'r') val_masks[cnt*num_val_per:(cnt+1)*num_val_per,:,:] \ = np.array(h5_mask['temporal_masks'][start_frame_val:val_every*num_val_per:val_every]) h5_mask.close() # generater for training and validation images and masks train_gen = data_gen(train_imgs, train_masks, batch_size=BATCH_SIZE, flips=True, rotate=True) if list_Exp_ID_val is not None: val_gen = data_gen(val_imgs, val_masks, batch_size=BATCH_SIZE, flips=False, rotate=False) if list_Exp_ID_val is None: val_gen = None NO_OF_VAL_IMAGES = 0 fff = get_shallow_unet(size=None, Params_loss=Params_loss) # The alternative line has more options to choose # fff = get_shallow_unet_more(size=None, n_depth=3, n_channel=4, skip=[1], activation='elu', Params_loss=Params_loss) if exist_model is not None: fff.load_weights(exist_model) class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, logs=None): print('\n\nThe average loss for epoch {} is {:7.4f}.'.format( epoch, logs['loss'])) # train CNN results = fff.fit_generator( train_gen, epochs=NO_OF_EPOCHS, steps_per_epoch=(NO_OF_TRAINING_IMAGES // BATCH_SIZE), validation_data=val_gen, validation_steps=(NO_OF_VAL_IMAGES // BATCH_SIZE), verbose=1, callbacks=[LossAndErrorPrintingCallback()]) # save trained CNN model 
fff.save_weights(file_CNN) return results
def suns_online_track(filename_video, filename_CNN, Params_pre, Params_post, dims, \ frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \ med_subtract=False, update_baseline=False, \ useWT=False, prealloc=True, display=True, useMP=True, p=None): '''The complete SUNS online procedure with tracking. It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post" to process the video "Exp_ID" in "dir_video" Inputs: filename_video (str): The path of the file of the input raw video. The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)). filename_CNN (str): The path of the trained CNN model. Params_pre (dict): Parameters for pre-processing. Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel Params_pre['num_median_approx'] (int): Number of frames used to compute the median and median-based standard deviation Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed. The remaining video is not considered a part of the input video. Params_post (dict): Parameters for post-processing. Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels). Params_post['avgArea']: The typical neuron area (unit: pixels). Params_post['thresh_pmap']: The probablity threshold. Values higher than thresh_pmap are active pixels. It is stored in uint8, so it should be converted to float32 before using. Params_post['thresh_mask']: Threashold to binarize the real-number mask. Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. Params_post['thresh_IOU']: Threshold of IOU used for merging neurons. 
Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons. Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for. dims (tuplel of int, shape = (2,)): lateral dimension of the raw video. frames_init (int): Number of frames used for initialization. merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames. batch_size_init (int, default to 1): batch size of CNN inference for initialization frames. useSF (bool, default to True): True if spatial filtering is used. useTF (bool, default to True): True if temporal filtering is used. useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used. med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering. Can only be used when spatial filtering is not used. update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames. useWT (bool, default to False): Indicator of whether watershed is used. prealloc (bool, default to True): True if pre-allocate memory space for large variables. Achieve faster speed at the cost of higher memory occupation. display (bool, default to True): Indicator of whether to show intermediate information useMP (bool, defaut to True): indicator of whether multiprocessing is used to speed up. p (multiprocessing.Pool, default to None): Outputs: Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
time_total (list of float, shape = (3,)): the total time spent for initalization, online processing, and total processing time_frame (list of float, shape = (3,)): the average time spent on every frame for initalization, online processing, and total processing list_time_per (1D numpy.array of float): Time (s) spend on every frame during online processing ''' if display: start = time.time() (Lx, Ly) = dims # zero-pad the lateral dimensions to multiples of 8, suitable for CNN rowspad = math.ceil(Lx/8)*8 colspad = math.ceil(Ly/8)*8 dimspad = (rowspad, colspad) Poisson_filt = Params_pre['Poisson_filt'] gauss_filt_size = Params_pre['gauss_filt_size'] nn = Params_pre['nn'] leng_tf = Poisson_filt.size leng_past = 2*leng_tf # number of past frames stored for temporal filtering list_time_per = np.zeros(nn) # Load CNN model fff = get_shallow_unet() fff.load_weights(filename_CNN) # run CNN inference once to warm up init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='float32') init_masks = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='uint8') fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init) del init_imgs, init_masks # load optimal post-processing parameters minArea = Params_post['minArea'] avgArea = Params_post['avgArea'] # thresh_pmap = Params_post['thresh_pmap'] thresh_mask = Params_post['thresh_mask'] # thresh_COM0 = Params_post['thresh_COM0'] # thresh_COM = Params_post['thresh_COM'] thresh_IOU = Params_post['thresh_IOU'] thresh_consume = Params_post['thresh_consume'] # cons = Params_post['cons'] # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256 thresh_pmap_float = (Params_post['thresh_pmap']+1)/256 # for published version # Spatial filtering preparation if useSF==True: # lateral dimensions slightly larger than the raw video but faster for FFT rows1 = cv2.getOptimalDFTSize(rowspad) cols1 = cv2.getOptimalDFTSize(colspad) # if the learned 2D and 3D wisdom files have been saved, load them. 
# Otherwise, learn wisdom later Length_data2=str((rows1, cols1)) cc2 = load_wisdom_txt('wisdom\\'+Length_data2) Length_data3=str((frames_init, rows1, cols1)) cc3 = load_wisdom_txt('wisdom\\'+Length_data3) if cc3: pyfftw.import_wisdom(cc3) # mask for spatial filter mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size) # FFT planning (bb, bf, fft_object_b, fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc) else: (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None) bb=np.zeros((frames_init, rowspad, colspad), dtype='float32') # Temporal filtering preparation frames_initf = frames_init - leng_tf + 1 if useTF==True: if prealloc: # past frames stored for temporal filtering past_frames = np.ones((leng_past, rowspad, colspad), dtype='float32') else: past_frames = np.zeros((leng_past, rowspad, colspad), dtype='float32') else: past_frames = None if prealloc: # Pre-allocate memory for some future variables med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32') video_input = np.ones((frames_initf, rowspad, colspad), dtype='float32') pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8') frame_SNR = np.ones(dimspad, dtype='float32') pmaps_b = np.ones(dims, dtype='uint8') if update_baseline: video_tf_past = np.ones((frames_init, rowspad, colspad), dtype='float32') else: med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32') video_input = np.zeros((frames_initf, rowspad, colspad), dtype='float32') pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8') frame_SNR = np.zeros(dimspad, dtype='float32') pmaps_b = np.zeros(dims, dtype='uint8') if update_baseline: video_tf_past = np.zeros((frames_init, rowspad, colspad), dtype='float32') if display: time_init = time.time() print('Parameter initialization time: {} s'.format(time_init-start)) if update_baseline: # Set the parameters for median update distri_update = False if useSF: # If spatial filtering is used, the padded rows and columns are nonzero Lu2 = colspad else: # If 
spatial filtering is not used, we can ignore the padded rows and columns Lu2 = Ly Lu1 = int(np.round(frames_initf / Lu2)) px_update = int(np.ceil(rowspad / Lu1)) Lu = Lu1 * Lu2 start_update = frames_initf + frames_init med_frame2_update = np.ones((px_update, 1, 2), dtype='float32') # %% Load raw video h5_img = h5py.File(filename_video, 'r') video_raw = np.array(h5_img['mov']) h5_img.close() nframes = video_raw.shape[0] nframesf = nframes - leng_tf + 1 bb[:, :Lx, :Ly] = video_raw[:frames_init] if display: time_load = time.time() print('Load data: {} s'.format(time_load-time_init)) # %% Actual processing starts after the video is loaded into memory # Initialization using the first "frames_init" frames print('Initialization of algorithms using the first {} frames'.format(frames_init)) if display: start_init = time.time() med_frame3, segs_all, recent_frames = init_online( bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \ med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \ useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \ useWT=useWT, batch_size_init=batch_size_init, p=p) if useTF==True: past_frames[:leng_tf] = recent_frames tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post) # Initialize Online track variables (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp) = tuple_temp # list of previously found neurons that satisfy consecutive frame requirement Masks_cons = select_cons(tuple_temp) # sparse matrix of previously found neurons that satisfy consecutive frame requirement Masks_cons_2D = sparse.vstack(Masks_cons) list_Masks_cons_2D = [Masks_cons_2D] # indices of previously found neurons that satisfy consecutive frame requirement ind_cons = have_cons_temp.nonzero()[0] segs0 = segs_all[0] # segs of initialization frames # segs if no neurons are found segs_empty = (segs0[0][0:0], segs0[1][0:0], segs0[2][0:0], segs0[3][0:0]) # Number of previously found neurons that satisfy 
consecutive frame requirement N1 = len(Masks_cons) # list of "segs" for neurons that are not previously found list_segs_new = [] # list of newly segmented masks for old neurons (segmented in previous frames) list_masks_old = [[] for _ in range(N1)] # list of the newly active indices of frames of old neurons times_active_old = [[] for _ in range(N1)] # True if the old neurons are active in the previous frame active_old_previous = np.zeros(N1, dtype='bool') if display: end_init = time.time() time_init = end_init-start_init time_frame_init = time_init/(frames_initf)*1000 print('Initialization time: {:6f} s, {:6f} ms/frame'.format(time_init, time_frame_init)) if display: start_online = time.time() # Spatial filtering preparation for online processing. # Attention: this part counts to the total time if useSF: if cc2: pyfftw.import_wisdom(cc2) (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1)) else: (bf, fft_object_b, fft_object_c) = (None, None, None) bb=np.zeros(dimspad, dtype='float32') if update_baseline: med_frame3_temp = np.zeros_like(med_frame3) print('Start frame by frame processing') # %% Online processing for the following frames current_frame = leng_tf+1 t_merge = frames_initf list_active_old = [] for t in range(frames_initf, nframesf): if display: start_frame = time.time() # load the current frame bb[:Lx, :Ly] = video_raw[t+leng_tf-1] bb[Lx:, :] = 0 bb[:, Ly:] = 0 # PreProcessing frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \ past_frames[current_frame-leng_tf:current_frame], mask2, bf, fft_object_b, fft_object_c, \ Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \ med_subtract=med_subtract, update_baseline=update_baseline) if update_baseline: t_past = (t-frames_initf) % frames_init video_tf_past[t_past] = frame_tf if distri_update: # update median and median-based standard deviation distributedly, column by column nu = (t-start_update) % Lu nu1 = nu // Lu1 nu2 = nu % Lu1 med_frame3_temp[:, 
nu2*px_update:(nu2+1)*px_update, nu1:nu1+1] = median_calculation( video_tf_past_fix[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1], \ med_frame2_update, (px_update,1), 1, display=False) if nu == Lu-1: med_frame3 = med_frame3_temp.copy() (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix) elif t >= start_update-1: distri_update = True (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix) # CNN inference frame_prob = CNN_online(frame_SNR, fff, dims) # first step of post-processing segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float, minArea, avgArea, useWT) active_old = np.zeros(N1, dtype='bool') # True if the old neurons are active in the current frame masks_t, neuronstate_t, cents_t, areas_t = segs N2 = neuronstate_t.size if N2: # Try to merge the new masks to old neurons new_found = np.zeros(N2, dtype='bool') for n2 in range(N2): masks_t2 = masks_t[n2] cents_t2 = np.round(cents_t[n2,1]) * Ly + np.round(cents_t[n2,0]) # If a new masks belongs to an old neuron, the COM of the new mask must be inside the old neuron area. # Select possible old neurons that the new mask can merge to possible_masks1 = Masks_cons_2D[:,cents_t2].nonzero()[0] IOUs = np.zeros(len(possible_masks1)) areas_t2 = areas_t[n2] for (ind,n1) in enumerate(possible_masks1): # Calculate IoU and consume ratio to determine merged neurons area_i = Masks_cons[n1].multiply(masks_t2).nnz area_temp1 = area_temp[n1] area_u = area_temp1 + areas_t2 - area_i IOU = area_i / area_u consume = area_i / min(area_temp1, areas_t2) contain = (IOU >= thresh_IOU) or (consume >= thresh_consume) if contain: # merging criterion satisfied IOUs[ind] = IOU num_contains = IOUs.nonzero()[0].size if num_contains: # The new mask can merge to one of the old neurons. 
# If there are multiple candicates, choose the one with the highest IoU belongs = possible_masks1[IOUs.argmax()] # merge the mask and active frame index list_masks_old[belongs].append(masks_t2) times_active_old[belongs].append(t + frames_initf) # This old neurons is active in the current frame active_old[belongs] = True else: # The new mask can not merge to any old neuron. new_found[n2] = True if np.any(new_found): # There are some new masks that can not merge to old neurons segs_new = (masks_t[new_found], neuronstate_t[new_found], cents_t[new_found], areas_t[new_found]) else: # All masks already merged to old neurons segs_new = segs_empty else: # No neurons fould in the current frame segs_new = segs list_segs_new.append(segs_new) if (t + 1 - t_merge) != merge_every or t == (nframesf-1): # Update the old neurons with new appearances in the current frame. if t < (nframesf-1): # True if the neurons are active in the previous frame but not active in the current frame inactive = np.logical_and(active_old_previous, np.logical_not(active_old)).nonzero()[0] else: # last frame # All active neurons should be updated, so they are treated as inactive in the next frame inactive = active_old_previous.nonzero()[0] # Update the indicators of the previous frame using the current frame active_old_previous = active_old.copy() for n1 in inactive: # merge new active frames to existing active frames for already found neurons # n1 is the index in the old neurons that satisfy consecutive frame requirement. # n10 is the index in all old neurons. 
n10 = ind_cons[n1] # Add all the new masks to the overall real-number masks mask_update = masks_temp[n10] + sum(list_masks_old[n1]) masks_temp[n10] = mask_update # Add indices of active frames times_add = np.unique(np.array(times_active_old[n1])) times_temp[n10] = np.hstack([times_temp[n10], times_add]) # reset lists used to store the information from new frames related to old neurons list_masks_old[n1] = [] times_active_old[n1] = [] # update the binary masks and areas Maskb_update = mask_update >= mask_update.max() * thresh_mask Masksb_temp[n10] = Maskb_update Masks_cons[n1] = Maskb_update area_temp[n10] = Maskb_update.nnz if inactive.size: Masks_cons_2D = sparse.vstack(Masks_cons) if (t + 1 - t_merge) == merge_every or t == (nframesf-1): if t < (nframesf-1): # delay merging new frame to next frame by assuming all the neurons active in the previous frame # are still active in the current frame, to reserve merging time for new neurons active_old_previous = np.logical_or(active_old_previous, active_old) # merge new neurons with old masks that do not satisfy consecutive frame requirement tuple_temp = (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp) # merge the remaining new masks from the most recent "merge_every" frames tuple_add = merge_complete(list_segs_new, dims, Params_post) (Masksb_add, masks_add, times_add, area_add, have_cons_add) = tuple_add # update the indices of active frames times_add = [x + merge_every for x in times_add] tuple_add = (Masksb_add, masks_add, times_add, area_add, have_cons_add) # merge the remaining new masks with the existing masks that do not satisfy consecutive frame requirement tuple_temp = merge_2_nocons(tuple_temp, tuple_add, dims, Params_post) (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp) = tuple_temp # Update the indices of old neurons that satisfy consecutive frame requirement ind_cons_new = have_cons_temp.nonzero()[0] for (ind,ind_cons_0) in enumerate(ind_cons_new): if ind_cons_0 not in 
ind_cons: # update lists used to store the information from new frames related to old neurons if ind_cons_0 > ind_cons.max(): list_masks_old.append([]) times_active_old.append([]) else: list_masks_old.insert(ind, []) times_active_old.insert(ind, []) # Update the list of previously found neurons that satisfy consecutive frame requirement Masks_cons = select_cons(tuple_temp) Masks_cons_2D = sparse.vstack(Masks_cons) N1 = len(Masks_cons) list_segs_new = [] # Update whether the old neurons are active in the previous frame active_old_previous = np.zeros_like(have_cons_temp) active_old_previous[ind_cons] = active_old active_old_previous = active_old_previous[ind_cons_new] ind_cons = ind_cons_new t_merge += merge_every current_frame +=1 # Update the stored latest frames when it runs out: move them "leng_tf" ahead if current_frame > leng_past: current_frame = leng_tf+1 past_frames[:leng_tf] = past_frames[-leng_tf:] if display: end_frame = time.time() list_time_per[t] = end_frame - start_frame if t % 1000 == 0: print('{} frames have been processed'.format(t)) list_Masks_cons_2D.append(Masks_cons_2D) list_active_old.append(active_old) Masks_cons = select_cons(tuple_temp) # final result. 
Masks_2 is a 2D sparse matrix of the segmented neurons if len(Masks_cons): Masks_2 = sparse.vstack(Masks_cons) else: Masks_2 = sparse.csc_matrix((0,dims[0]*dims[1])) list_Masks_cons_2D.append(Masks_2) if display: end_online = time.time() time_online = end_online-start_online time_frame_online = time_online/(nframesf-frames_initf)*1000 print('Online time: {:6f} s, {:6f} ms/frame'.format(time_online, time_frame_online)) # Save total processing time, and average processing time per frame if display: end_final = time.time() time_all = end_final-start_init time_frame_all = time_all/nframes*1000 print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all)) time_total = np.array([time_init, time_online, time_all]) time_frame = np.array([time_frame_init, time_frame_online, time_frame_all]) else: time_total = np.zeros((3,)) time_frame = np.zeros((3,)) # convert to a 3D array of the segmented neurons Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool') return Masks, Masks_2, time_total, time_frame, list_time_per, list_Masks_cons_2D, list_active_old
def suns_batch(dir_video, Exp_ID, filename_CNN, Params_pre, Params_post, dims,
        batch_size_eval=1, useSF=True, useTF=True, useSNR=True, med_subtract=False,
        useWT=False, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS batch procedure.
    It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters
    in "Params_post" to process the video "Exp_ID" in "dir_video".

    Inputs:
        dir_video (str): The folder containing the input video.
            Each file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        Exp_ID (str): The file name of the input raw video.
        filename_CNN (str): The path of the trained CNN model.
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels.
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        batch_size_eval (int, default to 1): batch size of CNN inference.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used.
        useWT (bool, default to False): Indicator of whether watershed is used.
        prealloc (bool, default to True): True if pre-allocate memory space for large variables.
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up.
            Not read in this function.
        p (multiprocessing.Pool, default to None): pool forwarded to "complete_segment".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks.
        Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix.
        time_total (numpy.ndarray of float, shape = (4,)): the total time spent for pre-processing, CNN inference,
            post-processing, and total processing (zeros if display is False)
        time_frame (numpy.ndarray of float, shape = (4,)): the average time spent on every frame
            for pre-processing, CNN inference, post-processing, and total processing (zeros if display is False)
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx/8)*8
    colspad = math.ceil(Ly/8)*8

    # load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_eval, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_eval, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_eval)
    del init_imgs, init_masks

    # Params_post['thresh_pmap'] is on the uint8 scale; map it onto the float probability scale.
    # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256
    thresh_pmap_float = (Params_post['thresh_pmap']+1)/256  # for published version

    if display:
        time_init = time.time()
        print('Initialization time: {} s'.format(time_init-start))

    # %% Actual processing starts after the video is loaded into memory,
    # which is in the middle of "preprocess_video", represented by the output "start".
    # pre-processing including loading data
    video_input, start = preprocess_video(dir_video, Exp_ID, Params_pre,
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, prealloc=prealloc, display=display)
    nframes = video_input.shape[0]
    if display:
        end_pre = time.time()
        time_pre = end_pre-start
        time_frame_pre = time_pre/nframes*1000
        print('Pre-Processing time: {:6f} s, {:6f} ms/frame'.format(time_pre, time_frame_pre))

    # CNN inference
    video_input = np.expand_dims(video_input, axis=-1)
    prob_map = fff.predict(video_input, batch_size=batch_size_eval)
    if display:
        end_network = time.time()
        time_CNN = end_network-end_pre
        time_frame_CNN = time_CNN/nframes*1000
        print('CNN Infrence time: {:6f} s, {:6f} ms/frame'.format(time_CNN, time_frame_CNN))

    # post-processing
    # Remove only the channel axis. A bare squeeze() would also drop the time axis of a
    # single-frame video and break the [:, :Lx, :Ly] crop.
    prob_map = prob_map.squeeze(axis=-1)[:, :Lx, :Ly]
    if display:
        print(Params_post)
    Params_post_copy = Params_post.copy()
    Params_post_copy['thresh_pmap'] = None  # Avoid repeated thresholding in postprocessing
    pmaps_b = np.zeros(prob_map.shape, dtype='uint8')
    # threshold the probability map to binary activity
    fastthreshold(prob_map, pmaps_b, thresh_pmap_float)

    # the rest of post-processing. The result is a 2D sparse matrix of the segmented neurons
    Masks_2 = complete_segment(pmaps_b, Params_post_copy, display=display, p=p, useWT=useWT)
    if display:
        finish = time.time()
        time_post = finish-end_network
        time_frame_post = time_post/nframes*1000
        print('Post-Processing time: {:6f} s, {:6f} ms/frame'.format(time_post, time_frame_post))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')

    # Save total processing time, and average processing time per frame
    if display:
        time_all = finish-start
        time_frame_all = time_all/nframes*1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_pre, time_CNN, time_post, time_all])
        time_frame = np.array([time_frame_pre, time_frame_CNN, time_frame_post, time_frame_all])
    else:
        time_total = np.zeros((4,))
        time_frame = np.zeros((4,))

    return Masks, Masks_2, time_total, time_frame
def _select_cons_sparse(tuple_temp, dims):
    '''Select neurons from the merged results "tuple_temp" using "select_cons",
    and stack their masks into a 2D sparse matrix of shape (n, dims[0]*dims[1]).
    An empty sparse matrix of shape (0, dims[0]*dims[1]) is returned when no neuron is selected.
    '''
    list_masks = select_cons(tuple_temp)
    if len(list_masks):
        return sparse.vstack(list_masks)
    return sparse.csc_matrix((0, dims[0]*dims[1]))


def suns_online(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, show_intermediate=True, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure.
    It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
    to process the raw video stored in "filename_video" frame by frame.

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model.
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merges the newly segmented frames every "merge_every" frames.
            Should be at least 5 (see Notes).
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        show_intermediate (bool, default to True): Indicator of whether 
            consecutive frame requirement is applied to screen neurons after every update. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
        p (multiprocessing.Pool, default to None): process pool passed to "init_online" for initialization.

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. 
        Masks_2 (scipy sparse matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
        time_total (numpy.ndarray of float, shape = (3,)): the total time spent for initialization, 
            online processing, and total processing. All zeros when "display" is False.
        time_frame (numpy.ndarray of float, shape = (3,)): the average time spent on every frame for initialization, 
            online processing, and total processing. All zeros when "display" is False.
        list_time_per (1D numpy.array of float, shape = (nn,)): Time (s) spent on every frame during online processing.
            Only filled when "display" is True.
        list_Masks_2 (list of scipy sparse matrix): the segmented masks after initialization and after every
            merging step if "show_intermediate" is True; otherwise only the final segmented masks.

    Notes:
        The temporal merging of newly segmented frames is staged over five consecutive frames,
        and "t_merge" only advances at the last stage. With "merge_every" < 5, no merging happens
        after the first cycle until the last frame, which still merges all remaining frames.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx/8)*8
    colspad = math.ceil(Ly/8)*8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2*leng_tf # number of past frames stored for temporal filtering
    list_time_per = np.zeros(nn)

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    thresh_mask = Params_post['thresh_mask']
    thresh_COM0 = Params_post['thresh_COM0']
    thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    cons = Params_post['cons']
    # "thresh_pmap" is stored on the uint8 scale; convert it to a float threshold (published version)
    thresh_pmap_float = (Params_post['thresh_pmap']+1)/256

    # Spatial filtering preparation
    if useSF:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)

        # if the learned 2D and 3D wisdom files have been saved, load them.
        # Otherwise, learn wisdom later.
        # NOTE(review): hard-coded Windows separator; on other OSes this is a literal backslash in the file name.
        Length_data2 = str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\'+Length_data2)

        Length_data3 = str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\'+Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b, fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    frames_initf = frames_init - leng_tf + 1
    if useTF:
        # past frames stored for temporal filtering
        if prealloc:
            past_frames = np.ones((leng_past, rowspad, colspad), dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad), dtype='float32')
    else:
        past_frames = None

    # Allocate memory for some future variables. With "prealloc", they are filled with ones,
    # presumably so the memory is touched up front rather than lazily.
    init_buffer = np.ones if prealloc else np.zeros
    med_frame2 = init_buffer((rowspad, colspad, 2), dtype='float32')
    video_input = init_buffer((frames_initf, rowspad, colspad), dtype='float32')
    pmaps_b_init = init_buffer((frames_initf, Lx, Ly), dtype='uint8')
    frame_SNR = init_buffer(dimspad, dtype='float32')
    pmaps_b = init_buffer(dims, dtype='uint8')
    if update_baseline:
        video_tf_past = init_buffer((frames_init, rowspad, colspad), dtype='float32')
        # always zero-initialized, whatever "prealloc" is
        video_tf_past_fix = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    if display:
        time_init = time.time()
        print('Parameter initialization time: {} s'.format(time_init-start))

    if update_baseline:
        # Set the parameters for median update
        distri_update = False
        if useSF: # If spatial filtering is used, the padded rows and columns are nonzero
            Lu2 = colspad
        else: # If spatial filtering is not used, we can ignore the padded rows and columns
            Lu2 = Ly
        # Every frame updates one block of "px_update" rows in one column,
        # so a complete update takes Lu = Lu1 * Lu2 frames.
        # Keep at least one row block; otherwise "px_update" divides by zero when frames_initf < Lu2/2.
        Lu1 = max(1, int(np.round(frames_initf / Lu2)))
        px_update = int(np.ceil(rowspad / Lu1))
        Lu = Lu1 * Lu2
        start_update = frames_initf + frames_init
        med_frame2_update = np.ones((px_update, 1, 2), dtype='float32')

    # %% Load raw video
    # Only the first "nn" frames are part of the input video (see Params_pre['nn'])
    with h5py.File(filename_video, 'r') as h5_img:
        mov = h5_img['mov']
        nframes = min(mov.shape[0], nn)
        video_raw = np.array(mov[:nframes])
    nframesf = nframes - leng_tf + 1
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load-time_init))

    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if useTF:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)
    # Created unconditionally: the final result is appended even when "show_intermediate" is False
    list_Masks_2 = []
    if show_intermediate:
        Masks_2 = _select_cons_sparse(tuple_temp, dims)
        list_Masks_2.append(Masks_2)
    if display:
        end_init = time.time()
        time_init = end_init-start_init
        time_frame_init = time_init/(frames_initf)*1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(time_init, time_frame_init))

    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing. 
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb = np.zeros(dimspad, dtype='float32')

    if update_baseline:
        med_frame3_temp = np.zeros_like(med_frame3)

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf+1
    t_merge = frames_initf
    for t in range(frames_initf, nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t+leng_tf-1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing. Without temporal filtering there is no frame history ("past_frames" is None).
        past_window = past_frames[current_frame-leng_tf:current_frame] if useTF else None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            past_window, mask2, bf, fft_object_b, \
            fft_object_c, Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t-frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if distri_update:
                # update median and median-based standard deviation distributedly, column by column
                nu = (t-start_update) % Lu
                nu1 = nu // Lu1
                nu2 = nu % Lu1
                med_frame3_temp[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1] = median_calculation(
                    video_tf_past_fix[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1], \
                    med_frame2_update, (px_update,1), 1, display=False)
                if nu == Lu-1:
                    med_frame3 = med_frame3_temp.copy()
                    (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)
            elif t >= start_update-1:
                distri_update = True
                (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float, minArea, avgArea, useWT)
        segs_all.append(segs)

        # The merging below is staged over five consecutive frames to spread the cost;
        # on the last frame, all stages run at once.
        # temporal merging 1: combine neurons with COM distance smaller than thresh_COM0
        if ((t + 1 - t_merge) == merge_every) or (t==nframesf-1):
            totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs_all[t_merge:])
            uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
                areas, probmapID, minArea=0, thresh_COM0=thresh_COM0, useMP=useMP)

        # temporal merging 2: combine neurons with COM distance smaller than thresh_COM
        if ((t - 0 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                groupedneurons, times_groupedneurons = \
                    group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques, useMP=useMP)

        # temporal merging 3: combine neurons with IoU larger than thresh_IOU
        if ((t - 1 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                piecedneurons_1, times_piecedneurons_1 = \
                    piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)

        # temporal merging 4: combine neurons with consume ratio larger than thresh_consume
        if ((t - 2 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                piecedneurons, times_piecedneurons = \
                    piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
                # masks of new neurons
                masks_add = piecedneurons
                # indices of frames when the neurons are active
                times_add = [np.unique(x) + t_merge for x in times_piecedneurons]

                # Refine neurons using consecutive occurrence
                if masks_add.size:
                    # new real-number masks
                    masks_add = [x for x in masks_add]
                    # new binary masks
                    Masksb_add = [(x >= x.max() * thresh_mask).astype('float') for x in masks_add]
                    # areas of new masks
                    area_add = np.array([x.nnz for x in Masksb_add])
                    # indicators of whether the new masks satisfy consecutive frame requirement
                    have_cons_add = refine_seperate_cons_online(times_add, cons)
                else:
                    Masksb_add = []
                    area_add = np.array([])
                    have_cons_add = np.array([])
            else: # does not find any active neuron
                Masksb_add = []
                masks_add = []
                times_add = times_uniques
                area_add = np.array([])
                have_cons_add = np.array([])
            tuple_add = (Masksb_add, masks_add, times_add, area_add, have_cons_add)

        # temporal merging 5: merge newly found neurons within the recent "merge_every" frames with existing neurons
        if ((t - 3 - t_merge) == merge_every) or (t==nframesf-1):
            tuple_temp = merge_2(tuple_temp, tuple_add, dims, Params_post)
            t_merge += merge_every
            if show_intermediate:
                Masks_2 = _select_cons_sparse(tuple_temp, dims)
                list_Masks_2.append(Masks_2)

        current_frame += 1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf+1
            if useTF:
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames have been processed'.format(t))

    if not show_intermediate:
        # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
        Masks_2 = _select_cons_sparse(tuple_temp, dims)
        list_Masks_2.append(Masks_2)
    if display:
        end_online = time.time()
        time_online = end_online-start_online
        # avoid dividing by zero when no frame is processed online
        time_frame_online = time_online/max(nframesf-frames_initf, 1)*1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final-start_init
        time_frame_all = time_all/nframes*1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array([time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3,))
        time_frame = np.zeros((3,))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame, list_time_per, list_Masks_2