Example #1
0
def parameter_optimization_pipeline(file_CNN, network_input, dims, \
        Params_set, filename_GT, batch_size_eval=1, useWT=False, useMP=True, p=None):
    '''Run CNN inference on one video, then score every post-processing parameter combination.
        The trained CNN in "file_CNN" is applied to every frame of "network_input" to obtain
        probability maps. The maps are cropped to "dims", converted to uint8, and passed to
        "parameter_optimization", which compares against the GT masks in "filename_GT".

    Inputs: 
        file_CNN (str): The path of the trained CNN model. Must be a ".h5" file. 
        network_input (3D numpy.ndarray of float32, shape = (T,Lx,Ly)): 
            the SNR video obtained after pre-processing.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
            The probability maps are cropped to these dimensions.
        Params_set (dict): Ranges of post-processing parameters to optimize over.
            Params_set['list_minArea']: (list) Range of minimum area of a valid neuron mask (unit: pixels).
            Params_set['list_avgArea']: (list) Range of typical neuron area (unit: pixels).
            Params_set['list_thresh_pmap']: (list) Range of probability threshold. 
            Params_set['thresh_mask']: (float) Threshold to binarize the real-number mask.
            Params_set['thresh_COM0']: (float) Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_set['list_thresh_COM']: (list) Range of threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_set['list_thresh_IOU']: (list) Range of threshold of IOU used for merging neurons.
            Params_set['thresh_consume']: (float) Threshold of consume ratio used for merging neurons.
            Params_set['list_cons']: (list) Range of minimum number of consecutive frames that a neuron should be active for.
        filename_GT (str): file name of the GT masks. 
            The file must be a ".mat" file, with dataset "GTMasks" being the 2D sparse matrix 
            (shape = (Ly0,Lx0,n) when saved in MATLAB).
        batch_size_eval (int, default to 1): batch size of CNN inference.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
        p (multiprocessing.Pool, default to None): forwarded unchanged to "parameter_optimization".

    Outputs:
        list_Recall (6D numpy.array of float): Recall for all parameter combinations. 
        list_Precision (6D numpy.array of float): Precision for all parameter combinations. 
        list_F1 (6D numpy.array of float): F1 for all parameter combinations. 
            For these outputs, the orders of the tunable parameters are:
            "minArea", "avgArea", "thresh_pmap", "thresh_COM", "thresh_IOU", "cons"
    '''
    n_rows, n_cols = dims

    # Build the network and restore the trained weights
    cnn = get_shallow_unet()
    cnn.load_weights(file_CNN)

    # Infer all frames and report the mean wall-clock time per frame
    n_frames = network_input.shape[0]
    tic = time.time()
    pmaps_float = cnn.predict(network_input, batch_size=batch_size_eval)
    toc = time.time()
    print('Average infrence time {} ms/frame'.format((toc - tic) / n_frames * 1000))

    # Drop the trailing channel axis and crop the padding away, then store the
    # maps as uint8 so that the following parameter sweep runs faster
    pmaps_float = pmaps_float.squeeze(axis=-1)[:, :n_rows, :n_cols]
    pmaps = np.zeros(pmaps_float.shape, dtype='uint8')
    fastuint(pmaps_float, pmaps)
    # The float maps and the network are no longer needed
    del pmaps_float, cnn

    # Recall, precision, and F1 for every combination of post-processing hyper-parameters
    recall, precision, f1 = parameter_optimization(
        pmaps, Params_set, filename_GT, useMP=useMP, useWT=useWT, p=p)
    return recall, precision, f1
Example #2
0
def _load_sampled_frames(dir_img, dir_mask, list_Exp_ID, num_per, start_frame, every, dims):
    '''Load regularly sampled SNR frames and their temporal masks from several videos.
        For each video, frames "start_frame, start_frame+every, ..." are selected, at most
        "num_per" of them. A video with fewer than "num_per" frames contributes all its frames.

    Inputs: 
        dir_img (str): The folder containing the ".h5" files with dataset "network_input".
        dir_mask (str): The folder containing the ".h5" files with dataset "temporal_masks".
        list_Exp_ID (list of str): The list of file names (without extension) of the videos.
        num_per (int): requested number of frames per video.
        start_frame (int): index of the first selected frame.
        every (int): step between selected frames.
        dims (tuple of int, shape = (2,)): lateral dimension of the video.

    Outputs:
        imgs (3D numpy.ndarray of float32): the selected SNR frames of all videos.
        masks (3D numpy.ndarray of uint8): the corresponding temporal masks.
            Both have at most num_per * len(list_Exp_ID) frames.
    '''
    (rows, cols) = dims
    capacity = num_per * len(list_Exp_ID)
    imgs = np.zeros((capacity, rows, cols), dtype='float32')
    masks = np.zeros((capacity, rows, cols), dtype='uint8')

    n_filled = 0  # number of frames written so far
    for Exp_ID in list_Exp_ID:
        # "with" guarantees that both files are closed even if reading fails
        with h5py.File(os.path.join(dir_img, Exp_ID + '.h5'), 'r') as h5_img, \
                h5py.File(os.path.join(dir_mask, Exp_ID + '.h5'), 'r') as h5_mask:
            num_frame = h5_img['network_input'].shape[0]
            if num_frame >= num_per:
                select = slice(start_frame, every * num_per, every)
            else:  # not enough frames: use the whole video
                select = slice(None)
            sel_imgs = np.array(h5_img['network_input'][select])
            sel_masks = np.array(h5_mask['temporal_masks'][select])

        # The selection can be shorter than "num_per" (short video), so the frames are
        # written at a running offset instead of a fixed per-video slot.
        n_sel = min(sel_imgs.shape[0], sel_masks.shape[0])
        imgs[n_filled:n_filled + n_sel] = sel_imgs[:n_sel]
        masks[n_filled:n_filled + n_sel] = sel_masks[:n_sel]
        n_filled += n_sel

    if n_filled < capacity:
        print('Warning: only {} of the requested {} frames were available.'.format(n_filled, capacity))
    # Trim the unused tail (a view; nothing is copied)
    return imgs[:n_filled], masks[:n_filled]


def train_CNN(dir_img, dir_mask, file_CNN, list_Exp_ID_train, list_Exp_ID_val,
    BATCH_SIZE, NO_OF_EPOCHS, num_train_per, num_total, dims, Params_loss=None, exist_model=None):
    '''Train a CNN model using SNR images in "dir_img" and the corresponding temporal masks in "dir_mask" 
        identified for each video in "list_Exp_ID_train" using the tensorflow generator formalism.
        The outputs are the trained CNN model saved in "file_CNN" and "results" containing loss. 

    Inputs: 
        dir_img (str): The folder containing the network_input (SNR images). 
            Each file must be a ".h5" file, with dataset "network_input" being the SNR video (shape = (T,Lx,Ly)).
        dir_mask (str): The folder containing the temporal masks. 
            Each file must be a ".h5" file, with dataset "temporal_masks" being the temporal masks (shape = (T,Lx,Ly)).
        file_CNN (str): The path to save the trained CNN model.
        list_Exp_ID_train (list of str): The list of file names of the training video(s). 
        list_Exp_ID_val (list of str, or None): The list of file names of the validation video(s). 
            If list_Exp_ID_val is None or empty, then no validation set is used.
        BATCH_SIZE (int): batch size for CNN training.
        NO_OF_EPOCHS (int): number of epochs for CNN training.
        num_train_per (int): number of training images per video.
            A video with fewer frames than this contributes all of its frames.
        num_total (int): total number of frames of a video (can be smaller than the actual number).
        dims (tuple of int, shape = (2,)): lateral dimension of the video.
        Params_loss(dict, default to None): parameters of the loss function "total_loss"
            Params_loss['DL'](float): Coefficient of dice loss in the total loss
            Params_loss['BCE'](float): Coefficient of binary cross entropy in the total loss
            Params_loss['FL'](float): Coefficient of focal loss in the total loss
            Params_loss['gamma'] (float): first parameter of focal loss
            Params_loss['alpha'] (float): second parameter of focal loss
        exist_model (str, default to None): the path of existing model for transfer learning 

    Outputs:
        results: the training results containing the loss information.
        In addition, the trained CNN model is saved in "file_CNN" as ".h5" files.
    '''
    nvideo_train = len(list_Exp_ID_train)  # Number of training videos
    # set how to choose training images
    train_every = max(1, num_total // num_train_per)
    start_frame_train = random.randint(0, train_every - 1)
    NO_OF_TRAINING_IMAGES = num_train_per * nvideo_train

    # An empty validation list is treated like None (it would otherwise divide by zero)
    use_val = list_Exp_ID_val is not None and len(list_Exp_ID_val) > 0
    if use_val:
        # set how to choose validation images
        nvideo_val = len(list_Exp_ID_val)  # Number of validation videos
        # the total number of validation images is about 1/9 of the training images
        num_val_per = int((num_train_per * nvideo_train / nvideo_val) // 9)
        # keep at least one validation image per video, so that "val_every" is well defined
        num_val_per = max(1, min(num_val_per, num_total))
        val_every = max(1, num_total // num_val_per)
        start_frame_val = random.randint(0, val_every - 1)
        NO_OF_VAL_IMAGES = num_val_per * nvideo_val

    # %% Load training images and masks from h5 files
    print('Loading training images and masks.')
    # Select training images: for each video, start from frame "start_frame_train",
    # select a frame every "train_every" frames, totally "num_train_per" frames
    train_imgs, train_masks = _load_sampled_frames(
        dir_img, dir_mask, list_Exp_ID_train, num_train_per, start_frame_train, train_every, dims)

    # generator for training images and masks
    train_gen = data_gen(train_imgs,
                         train_masks,
                         batch_size=BATCH_SIZE,
                         flips=True,
                         rotate=True)
    if use_val:
        # Select validation images: for each video, start from frame "start_frame_val",
        # select a frame every "val_every" frames, totally "num_val_per" frames
        val_imgs, val_masks = _load_sampled_frames(
            dir_img, dir_mask, list_Exp_ID_val, num_val_per, start_frame_val, val_every, dims)
        # generator for validation images and masks (no augmentation)
        val_gen = data_gen(val_imgs,
                           val_masks,
                           batch_size=BATCH_SIZE,
                           flips=False,
                           rotate=False)
    else:
        val_gen = None
        NO_OF_VAL_IMAGES = 0

    fff = get_shallow_unet(size=None, Params_loss=Params_loss)
    # The alternative line has more options to choose
    # fff = get_shallow_unet_more(size=None, n_depth=3, n_channel=4, skip=[1], activation='elu', Params_loss=Params_loss)
    if exist_model is not None:
        # transfer learning: start from the weights of an existing model
        fff.load_weights(exist_model)

    class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
        '''Print the average training loss at the end of every epoch.'''
        def on_epoch_end(self, epoch, logs=None):
            print('\n\nThe average loss for epoch {} is {:7.4f}.'.format(
                epoch, logs['loss']))

    # train CNN
    # NOTE(review): "fit_generator" is deprecated in recent TensorFlow releases in favor of
    # "fit"; it is kept here to stay compatible with the version this project targets.
    results = fff.fit_generator(
        train_gen,
        epochs=NO_OF_EPOCHS,
        steps_per_epoch=(NO_OF_TRAINING_IMAGES // BATCH_SIZE),
        validation_data=val_gen,
        validation_steps=(NO_OF_VAL_IMAGES // BATCH_SIZE),
        verbose=1,
        callbacks=[LossAndErrorPrintingCallback()])

    # save trained CNN model
    fff.save_weights(file_CNN)
    return results
Example #3
0
def suns_online_track(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure with tracking.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video stored in "filename_video". The first "frames_init" frames are processed 
        together for initialization; the remaining frames are processed one by one, and the masks found
        in each frame are merged into the previously found neurons.

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                In this function it only sets the length of "list_time_per".
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
            and to record the timing outputs.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
            It is not used inside this function.
        p (multiprocessing.Pool, default to None): forwarded to "init_online".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. 
        Masks_2 (scipy sparse matrix, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
        time_total (1D numpy.ndarray of float, shape = (3,)): the total time (s) spent 
            for initialization, online processing, and total processing. All zeros if "display" is False.
        time_frame (1D numpy.ndarray of float, shape = (3,)): the average time (ms) spent on every frame
            for initialization, online processing, and total processing. All zeros if "display" is False.
        list_time_per (1D numpy.ndarray of float, shape = (nn,)): Time (s) spent on every frame 
            during online processing. Only filled if "display" is True.
        list_Masks_cons_2D (list of scipy sparse matrix): The sparse matrix of the neurons satisfying 
            the consecutive frame requirement after initialization and after every online frame, 
            followed by the final "Masks_2".
        list_active_old (list of 1D numpy.ndarray of bool): For every online frame, 
            the indicator of which previously found neurons are active in that frame.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx/8)*8
    colspad = math.ceil(Ly/8)*8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2*leng_tf # number of past frames stored for temporal filtering
    # NOTE(review): indexed by the frame index "t" below, so "nn" must not be smaller
    # than the number of temporally filtered frames when "display" is True -- confirm with callers.
    list_time_per = np.zeros(nn)

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    # thresh_pmap = Params_post['thresh_pmap']
    thresh_mask = Params_post['thresh_mask']
    # thresh_COM0 = Params_post['thresh_COM0']
    # thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    # cons = Params_post['cons']
    # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256
    thresh_pmap_float = (Params_post['thresh_pmap']+1)/256 # for published version


    # Spatial filtering preparation
    if useSF==True:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)
        
        # if the learned 2D and 3D wisdom files have been saved, load them. 
        # Otherwise, learn wisdom later
        # NOTE(review): the path uses a hard-coded Windows separator. It is kept unchanged 
        # because it must match the code that saves the wisdom files -- confirm before porting.
        Length_data2=str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\'+Length_data2)
        
        Length_data3=str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\'+Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b, fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb=np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    frames_initf = frames_init - leng_tf + 1
    if useTF==True:
        if prealloc:
            # past frames stored for temporal filtering
            past_frames = np.ones((leng_past, rowspad, colspad), dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad), dtype='float32')
    else:
        past_frames = None
    
    if prealloc: # Pre-allocate memory for some future variables
        med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        video_input = np.ones((frames_initf, rowspad, colspad), dtype='float32')        
        pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8')        
        frame_SNR = np.ones(dimspad, dtype='float32')
        pmaps_b = np.ones(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.ones((frames_init, rowspad, colspad), dtype='float32')        
            # Second buffer, swapped with "video_tf_past" during the baseline update.
            # It must exist before the first swap in the online loop.
            video_tf_past_fix = np.ones((frames_init, rowspad, colspad), dtype='float32')        
    else:
        med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        video_input = np.zeros((frames_initf, rowspad, colspad), dtype='float32')        
        pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8')        
        frame_SNR = np.zeros(dimspad, dtype='float32')
        pmaps_b = np.zeros(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.zeros((frames_init, rowspad, colspad), dtype='float32')        
            # Second buffer, swapped with "video_tf_past" during the baseline update.
            # It must exist before the first swap in the online loop.
            video_tf_past_fix = np.zeros((frames_init, rowspad, colspad), dtype='float32')        

    if display:
        time_init = time.time()
        print('Parameter initialization time: {} s'.format(time_init-start))

    if update_baseline: # Set the parameters for median update
        distri_update = False
        if useSF: # If spatial filtering is used, the padded rows and columns are nonzero
            Lu2 = colspad
        else: # If spatial filtering is not used, we can ignore the padded rows and columns
            Lu2 = Ly
        Lu1 = int(np.round(frames_initf / Lu2))
        px_update = int(np.ceil(rowspad / Lu1))
        Lu = Lu1 * Lu2
        start_update = frames_initf + frames_init
        med_frame2_update = np.ones((px_update, 1, 2), dtype='float32')


    # %% Load raw video
    # "with" guarantees that the file is closed even if reading fails
    with h5py.File(filename_video, 'r') as h5_img:
        video_raw = np.array(h5_img['mov'])
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load-time_init))


    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if useTF==True:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)

    # Initialize Online track variables
    (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp) = tuple_temp
    # list of previously found neurons that satisfy consecutive frame requirement
    Masks_cons = select_cons(tuple_temp) 
    # sparse matrix of previously found neurons that satisfy consecutive frame requirement
    Masks_cons_2D = sparse.vstack(Masks_cons) 
    list_Masks_cons_2D = [Masks_cons_2D]
    # indices of previously found neurons that satisfy consecutive frame requirement
    ind_cons = have_cons_temp.nonzero()[0]
    segs0 = segs_all[0] # segs of initialization frames
    # segs if no neurons are found
    segs_empty = (segs0[0][0:0], segs0[1][0:0], segs0[2][0:0], segs0[3][0:0]) 
    # Number of previously found neurons that satisfy consecutive frame requirement
    N1 = len(Masks_cons)
    # list of "segs" for neurons that are not previously found
    list_segs_new = [] 
    # list of newly segmented masks for old neurons (segmented in previous frames)
    list_masks_old = [[] for _ in range(N1)] 
    # list of the newly active indices of frames of old neurons
    times_active_old = [[] for _ in range(N1)] 
    # True if the old neurons are active in the previous frame
    active_old_previous = np.zeros(N1, dtype='bool')

    if display:
        end_init = time.time()
        # from here on, "time_init" is the duration of the initialization
        time_init = end_init-start_init
        time_frame_init = time_init/(frames_initf)*1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(time_init, time_frame_init))


    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing. 
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb=np.zeros(dimspad, dtype='float32')
    if update_baseline:
        med_frame3_temp = np.zeros_like(med_frame3)

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf+1
    t_merge = frames_initf
    list_active_old = []

    for t in range(frames_initf, nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t+leng_tf-1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing
        # "past_frames" is None when temporal filtering is disabled, so it can not be sliced.
        # NOTE(review): assumes "preprocess_online" ignores this argument when useTF is False.
        if useTF:
            past_window = past_frames[current_frame-leng_tf:current_frame]
        else:
            past_window = None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            past_window, mask2, bf, fft_object_b, fft_object_c, \
            Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t-frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if distri_update:
                # update median and median-based standard deviation distributedly, column by column
                nu = (t-start_update) % Lu
                nu1 = nu // Lu1
                nu2 = nu % Lu1
                med_frame3_temp[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1] = median_calculation(
                    video_tf_past_fix[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1], \
                    med_frame2_update, (px_update,1), 1, display=False)
                if nu == Lu-1:
                    med_frame3 = med_frame3_temp.copy()
                    # swap the buffer being filled and the buffer used for the update
                    (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)
            elif t >= start_update-1:
                distri_update = True
                # swap the buffer being filled and the buffer used for the update
                (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float, minArea, avgArea, useWT)

        active_old = np.zeros(N1, dtype='bool') # True if the old neurons are active in the current frame
        masks_t, neuronstate_t, cents_t, areas_t = segs
        N2 = neuronstate_t.size
        if N2: # Try to merge the new masks to old neurons
            new_found = np.zeros(N2, dtype='bool')
            for n2 in range(N2):
                masks_t2 = masks_t[n2]
                cents_t2 = np.round(cents_t[n2,1]) * Ly + np.round(cents_t[n2,0])  
                # If a new masks belongs to an old neuron, the COM of the new mask must be inside the old neuron area.
                # Select possible old neurons that the new mask can merge to
                possible_masks1 = Masks_cons_2D[:,cents_t2].nonzero()[0]
                IOUs = np.zeros(len(possible_masks1))
                areas_t2 = areas_t[n2]
                for (ind,n1) in enumerate(possible_masks1):
                    # Calculate IoU and consume ratio to determine merged neurons
                    area_i = Masks_cons[n1].multiply(masks_t2).nnz
                    area_temp1 = area_temp[n1]
                    area_u = area_temp1 + areas_t2 - area_i
                    IOU = area_i / area_u
                    consume = area_i / min(area_temp1, areas_t2)
                    contain = (IOU >= thresh_IOU) or (consume >= thresh_consume)
                    if contain: # merging criterion satisfied
                        IOUs[ind] = IOU
                num_contains = IOUs.nonzero()[0].size
                if num_contains: # The new mask can merge to one of the old neurons.
                    # If there are multiple candidates, choose the one with the highest IoU
                    belongs = possible_masks1[IOUs.argmax()]
                    # merge the mask and active frame index
                    list_masks_old[belongs].append(masks_t2)
                    times_active_old[belongs].append(t + frames_initf)
                    # This old neurons is active in the current frame
                    active_old[belongs] = True 
                else: # The new mask can not merge to any old neuron.
                    new_found[n2] = True

            if np.any(new_found): # There are some new masks that can not merge to old neurons
                segs_new = (masks_t[new_found], neuronstate_t[new_found], cents_t[new_found], areas_t[new_found])
            else: # All masks already merged to old neurons
                segs_new = segs_empty
                
        else: # No neurons found in the current frame
            segs_new = segs
        list_segs_new.append(segs_new)

        if (t + 1 - t_merge) != merge_every or t == (nframesf-1):
            # Update the old neurons with new appearances in the current frame.
            if t < (nframesf-1):
                # True if the neurons are active in the previous frame but not active in the current frame
                inactive = np.logical_and(active_old_previous, np.logical_not(active_old)).nonzero()[0]
            else: # last frame
                # All active neurons should be updated, so they are treated as inactive in the next frame
                inactive = active_old_previous.nonzero()[0]

            # Update the indicators of the previous frame using the current frame
            active_old_previous = active_old.copy()
            for n1 in inactive: 
                # merge new active frames to existing active frames for already found neurons
                # n1 is the index in the old neurons that satisfy consecutive frame requirement. 
                # n10 is the index in all old neurons.
                n10 = ind_cons[n1] 
                # Add all the new masks to the overall real-number masks
                mask_update = masks_temp[n10] + sum(list_masks_old[n1])
                masks_temp[n10] = mask_update
                # Add indices of active frames
                times_add = np.unique(np.array(times_active_old[n1]))
                times_temp[n10] = np.hstack([times_temp[n10], times_add])
                # reset lists used to store the information from new frames related to old neurons
                list_masks_old[n1] = []
                times_active_old[n1] = []
                # update the binary masks and areas
                Maskb_update = mask_update >= mask_update.max() * thresh_mask
                Masksb_temp[n10] = Maskb_update
                Masks_cons[n1] = Maskb_update
                area_temp[n10] = Maskb_update.nnz
            if inactive.size:
                Masks_cons_2D = sparse.vstack(Masks_cons) 

        if (t + 1 - t_merge) == merge_every or t == (nframesf-1):
            if t < (nframesf-1):
                # delay merging new frame to next frame by assuming all the neurons active in the previous frame
                # are still active in the current frame, to reserve merging time for new neurons
                active_old_previous = np.logical_or(active_old_previous, active_old)

            # merge new neurons with old masks that do not satisfy consecutive frame requirement
            tuple_temp = (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp)
            # merge the remaining new masks from the most recent "merge_every" frames
            tuple_add = merge_complete(list_segs_new, dims, Params_post)
            (Masksb_add, masks_add, times_add, area_add, have_cons_add) = tuple_add
            # update the indices of active frames
            times_add = [x + merge_every for x in times_add]
            tuple_add = (Masksb_add, masks_add, times_add, area_add, have_cons_add)
            # merge the remaining new masks with the existing masks that do not satisfy consecutive frame requirement
            tuple_temp = merge_2_nocons(tuple_temp, tuple_add, dims, Params_post)

            (Masksb_temp, masks_temp, times_temp, area_temp, have_cons_temp) = tuple_temp
            # Update the indices of old neurons that satisfy consecutive frame requirement
            ind_cons_new = have_cons_temp.nonzero()[0]
            for (ind,ind_cons_0) in enumerate(ind_cons_new):
                if ind_cons_0 not in ind_cons: 
                    # update lists used to store the information from new frames related to old neurons
                    if ind_cons_0 > ind_cons.max():
                        list_masks_old.append([])
                        times_active_old.append([])
                    else:
                        list_masks_old.insert(ind, [])
                        times_active_old.insert(ind, [])
            
            # Update the list of previously found neurons that satisfy consecutive frame requirement
            Masks_cons = select_cons(tuple_temp)
            Masks_cons_2D = sparse.vstack(Masks_cons) 
            N1 = len(Masks_cons)
            list_segs_new = []
            # Update whether the old neurons are active in the previous frame
            active_old_previous = np.zeros_like(have_cons_temp)
            active_old_previous[ind_cons] = active_old
            active_old_previous = active_old_previous[ind_cons_new]
            ind_cons = ind_cons_new
            t_merge += merge_every

        current_frame +=1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf+1
            if useTF: # there are no stored frames when temporal filtering is disabled
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames have been processed'.format(t))
        # record the state of the tracked neurons after this frame
        list_Masks_cons_2D.append(Masks_cons_2D)
        list_active_old.append(active_old)

    Masks_cons = select_cons(tuple_temp)
    # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
    if len(Masks_cons):
        Masks_2 = sparse.vstack(Masks_cons)
    else:
        Masks_2 = sparse.csc_matrix((0,dims[0]*dims[1]))
    list_Masks_cons_2D.append(Masks_2)

    if display:
        end_online = time.time()
        time_online = end_online-start_online
        time_frame_online = time_online/(nframesf-frames_initf)*1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final-start_init
        time_frame_all = time_all/nframes*1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array([time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3,))
        time_frame = np.zeros((3,))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame, list_time_per, list_Masks_cons_2D, list_active_old
Example #4
0
def suns_batch(dir_video, Exp_ID, filename_CNN, Params_pre, Params_post, dims, \
        batch_size_eval=1, useSF=True, useTF=True, useSNR=True, med_subtract=False, \
        useWT=False, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS batch procedure.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video "Exp_ID" in "dir_video"

    Inputs: 
        dir_video (str): The folder containing the input video.
            Each file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        Exp_ID (str): The file name of the input raw video. 
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
        Params_post (dict): Parameters for post-processing. The dict itself is not modified; 
            a shallow copy with 'thresh_pmap' set to None is passed to "complete_segment".
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it is converted to a float threshold, (thresh_pmap+1)/256, before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        batch_size_eval (int, default to 1): batch size of CNN inference.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        useWT (bool, default to False): Indicator of whether watershed is used. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information.
            The timing outputs are only measured when it is True.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
            NOTE(review): this argument is accepted but never referenced in this function.
        p (multiprocessing.Pool, default to None): forwarded unchanged to "complete_segment".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx,Ly)): the final segmented masks. 
        Masks_2 (2D sparse matrix, shape = (n,Lx*Ly)): the final segmented masks 
            in the form of sparse matrix, as returned by "complete_segment". 
        time_total (1D numpy.ndarray of float, shape = (4,)): the total time spent 
            for pre-processing, CNN inference, post-processing, and total processing.
            All zeros when "display" is False.
        time_frame (1D numpy.ndarray of float, shape = (4,)): the average time (ms) spent on every frame
            for pre-processing, CNN inference, post-processing, and total processing.
            All zeros when "display" is False.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx/8)*8
    colspad = math.ceil(Ly/8)*8
    # load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_eval, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_eval, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_eval)
    del init_imgs, init_masks

    # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256
    thresh_pmap_float = (Params_post['thresh_pmap']+1)/256 # for published version
    if display:
        time_init = time.time()
        print('Initialization time: {} s'.format(time_init-start))

    # %% Actual processing starts after the video is loaded into memory
        # which is in the middle of "preprocess_video", represented by the output "start"
    # pre-processing including loading data
    video_input, start = preprocess_video(dir_video, Exp_ID, Params_pre, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, prealloc=prealloc, display=display)
    nframes = video_input.shape[0]
    if display:
        end_pre = time.time()
        time_pre = end_pre-start
        time_frame_pre = time_pre/nframes*1000
        print('Pre-Processing time: {:6f} s, {:6f} ms/frame'.format(time_pre, time_frame_pre))

    # CNN inference
    video_input = np.expand_dims(video_input, axis=-1)
    prob_map = fff.predict(video_input, batch_size=batch_size_eval)
    if display:
        end_network = time.time()
        time_CNN = end_network-end_pre
        time_frame_CNN = time_CNN/nframes*1000
        print('CNN Infrence time: {:6f} s, {:6f} ms/frame'.format(time_CNN, time_frame_CNN))

    # post-processing
    # Remove only the trailing channel axis, then crop the zero-padding. 
    # A bare squeeze() would also drop the frame axis of a single-frame video 
    # (or a lateral axis of length 1), so the crop would index the wrong axes.
    prob_map = prob_map.squeeze(axis=-1)[:, :Lx, :Ly]
    print(Params_post)
    Params_post_copy = Params_post.copy()
    Params_post_copy['thresh_pmap'] = None # Avoid repeated thresholding in postprocessing
    pmaps_b = np.zeros(prob_map.shape, dtype='uint8')
    # threshold the probability map to binary activity
    fastthreshold(prob_map, pmaps_b, thresh_pmap_float)

    # the rest of post-processing. The result is a 2D sparse matrix of the segmented neurons
    Masks_2 = complete_segment(pmaps_b, Params_post_copy, display=display, p=p, useWT=useWT)
    if display:
        finish = time.time()
        time_post = finish-end_network
        time_frame_post = time_post/nframes*1000
        print('Post-Processing time: {:6f} s, {:6f} ms/frame'.format(time_post, time_frame_post))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')

    # Save total processing time, and average processing time per frame
    if display:
        # "start" here is the value returned by "preprocess_video", 
        # so model loading and warm-up are excluded from the total
        time_all = finish-start
        time_frame_all = time_all/nframes*1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_pre, time_CNN, time_post, time_all])
        time_frame = np.array([time_frame_pre, time_frame_CNN, time_frame_post, time_frame_all])
    else:
        time_total = np.zeros((4,))
        time_frame = np.zeros((4,))

    return Masks, Masks_2, time_total, time_frame
Example #5
0
def suns_online(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, show_intermediate=True, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video stored in "filename_video". 
        The first "frames_init" frames are processed as a batch for initialization, 
        and the remaining frames are processed one by one.

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                In this function it only sets the initial length of "list_time_per"; 
                the whole dataset "mov" is loaded and processed.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it is converted to a float threshold, (thresh_pmap+1)/256, before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames.
            NOTE(review): the five merging stages are spread over five consecutive frames, 
            so "merge_every" presumably has to be larger than 4 for successive rounds not to collide; confirm.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        show_intermediate (bool, default to True): Indicator of whether 
            consecutive frame requirement is applied to screen neurons after every update. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information.
            The timing outputs are only measured when it is True.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
            Forwarded to "unique_neurons2_simp" and "group_neurons".
        p (multiprocessing.Pool, default to None): forwarded unchanged to "init_online".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx,Ly)): the final segmented masks. 
        Masks_2 (2D sparse matrix, shape = (n,Lx*Ly)): the final segmented masks in the form of sparse matrix. 
        time_total (1D numpy.ndarray of float, shape = (3,)): the total time spent 
            for initialization, online processing, and total processing. All zeros when "display" is False.
        time_frame (1D numpy.ndarray of float, shape = (3,)): the average time (ms) spent on every frame
            for initialization, online processing, and total processing. All zeros when "display" is False.
        list_time_per (1D numpy.ndarray of float): Time (s) spent on every frame during online processing, 
            indexed by the frame number after temporal filtering. Only filled when "display" is True.
        list_Masks_2 (list of 2D sparse matrix): When "show_intermediate" is True, the masks obtained 
            after initialization and after every merge. The final "Masks_2" is always appended as the last element.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx/8)*8
    colspad = math.ceil(Ly/8)*8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2*leng_tf # number of past frames stored for temporal filtering
    list_time_per = np.zeros(nn)
    # Always defined, so that it can be appended to and returned even when show_intermediate is False
    list_Masks_2 = []

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1), dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    # thresh_pmap = Params_post['thresh_pmap']
    thresh_mask = Params_post['thresh_mask']
    thresh_COM0 = Params_post['thresh_COM0']
    thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    cons = Params_post['cons']
    # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256
    thresh_pmap_float = (Params_post['thresh_pmap']+1)/256 # for published version


    # Spatial filtering preparation
    # (tested by truthiness, consistent with the "if useSF:" test before online processing)
    if useSF:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)
        
        # if the learned 2D and 3D wisdom files have been saved, load them. 
        # Otherwise, learn wisdom later
        # NOTE(review): the path uses a hard-coded backslash separator; on POSIX systems 
        # the backslash is part of the file name rather than a directory separator. 
        # Confirm against the code that saves the wisdom files before changing it.
        Length_data2=str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\'+Length_data2)
        
        Length_data3=str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\'+Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b, fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb=np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    # number of initialization frames remaining after temporal filtering
    frames_initf = frames_init - leng_tf + 1
    if useTF:
        if prealloc:
            # past frames stored for temporal filtering
            past_frames = np.ones((leng_past, rowspad, colspad), dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad), dtype='float32')
    else:
        past_frames = None
    
    if prealloc: # Pre-allocate memory for some future variables
        med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        video_input = np.ones((frames_initf, rowspad, colspad), dtype='float32')        
        pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8')        
        frame_SNR = np.ones(dimspad, dtype='float32')
        pmaps_b = np.ones(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.ones((frames_init, rowspad, colspad), dtype='float32')        
            video_tf_past_fix = np.zeros((frames_init, rowspad, colspad), dtype='float32')        
    else:
        med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        video_input = np.zeros((frames_initf, rowspad, colspad), dtype='float32')        
        pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8')        
        frame_SNR = np.zeros(dimspad, dtype='float32')
        pmaps_b = np.zeros(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.zeros((frames_init, rowspad, colspad), dtype='float32')        
            video_tf_past_fix = np.zeros((frames_init, rowspad, colspad), dtype='float32')        

    if display:
        time_init = time.time()
        print('Parameter initialization time: {} s'.format(time_init-start))

    if update_baseline: # Set the parameters for median update
        distri_update = False
        if useSF: # If spatial filtering is used, the padded rows and columns are nonzero
            Lu2 = colspad
        else: # If spatial filtering is not used, we can ignore the padded rows and columns
            Lu2 = Ly
        Lu1 = int(np.round(frames_initf / Lu2))
        px_update = int(np.ceil(rowspad / Lu1))
        Lu = Lu1 * Lu2
        start_update = frames_initf + frames_init
        med_frame2_update = np.ones((px_update, 1, 2), dtype='float32')


    # %% Load raw video
    # The context manager closes the file even if reading the dataset raises
    with h5py.File(filename_video, 'r') as h5_img:
        video_raw = np.array(h5_img['mov'])
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load-time_init))
        # "list_time_per" is indexed by frame number below; 
        # enlarge it if the video has more frames than Params_pre['nn']
        if nframesf > list_time_per.size:
            list_time_per = np.zeros(nframesf)


    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if past_frames is not None:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)
    if show_intermediate:
        Masks_2 = select_cons(tuple_temp)
        if len(Masks_2):
            Masks_2 = sparse.vstack(Masks_2)
        else:
            Masks_2 = sparse.csc_matrix((0,dims[0]*dims[1]))
        list_Masks_2.append(Masks_2)

    if display:
        end_init = time.time()
        time_init = end_init-start_init
        time_frame_init = time_init/(frames_initf)*1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(time_init, time_frame_init))


    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing. 
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb=np.zeros(dimspad, dtype='float32')
    if update_baseline:
        med_frame3_temp = np.zeros_like(med_frame3)
    
    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf+1
    t_merge = frames_initf

    for t in range(frames_initf,nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t+leng_tf-1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0
        
        # PreProcessing
        # Without temporal filtering there is no buffer of past frames to slice
        if past_frames is not None:
            frames_window = past_frames[current_frame-leng_tf:current_frame]
        else:
            frames_window = None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            frames_window, mask2, bf, fft_object_b, \
            fft_object_c, Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t-frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if distri_update:
                # update median and median-based standard deviation distributedly, column by column
                nu = (t-start_update) % Lu
                nu1 = nu // Lu1
                nu2 = nu % Lu1
                med_frame3_temp[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1] = median_calculation(
                    video_tf_past_fix[:, nu2*px_update:(nu2+1)*px_update, nu1:nu1+1], \
                    med_frame2_update, (px_update,1), 1, display=False)
                if nu == Lu-1:
                    med_frame3 = med_frame3_temp.copy()
                    (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)
            elif t >= start_update-1:
                distri_update = True
                (video_tf_past_fix, video_tf_past) = (video_tf_past, video_tf_past_fix)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float, minArea, avgArea, useWT)
        segs_all.append(segs)

        # The five merging stages below are triggered on five consecutive frames 
        # (one stage per frame); on the last frame all of them run at once.

        # temporal merging 1: combine neurons with COM distance smaller than thresh_COM0
        if ((t + 1 - t_merge) == merge_every) or (t==nframesf-1):
            # uniques, times_uniques = unique_neurons1_simp(segs_all[t_merge:], thresh_COM0) # minArea,
            totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs_all[t_merge:])
            uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
                areas, probmapID, minArea=0, thresh_COM0=thresh_COM0, useMP=useMP)

        # temporal merging 2: combine neurons with COM distance smaller than thresh_COM
        if ((t - 0 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                groupedneurons, times_groupedneurons = \
                    group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques, useMP=useMP)

        # temporal merging 3: combine neurons with IoU larger than thresh_IOU
        if ((t - 1 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                piecedneurons_1, times_piecedneurons_1 = \
                    piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)

        # temporal merging 4: combine neurons with consume ratio larger than thresh_consume
        if ((t - 2 - t_merge) == merge_every) or (t==nframesf-1):
            if uniques.size:
                piecedneurons, times_piecedneurons = \
                    piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
                # masks of new neurons
                masks_add = piecedneurons
                # indices of frames when the neurons are active
                times_add = [np.unique(x) + t_merge for x in times_piecedneurons]
                    
                # Refine neurons using consecutive occurence
                if masks_add.size:
                    # new real-number masks
                    masks_add = [x for x in masks_add]
                    # new binary masks
                    Masksb_add = [(x >= x.max() * thresh_mask).astype('float') for x in masks_add]
                    # areas of new masks
                    area_add = np.array([x.nnz for x in Masksb_add])
                    # indicators of whether the new masks satisfy consecutive frame requirement
                    have_cons_add = refine_seperate_cons_online(times_add, cons)
                else:
                    Masksb_add = []
                    area_add = np.array([])
                    have_cons_add = np.array([])

            else: # does not find any active neuron
                Masksb_add = []
                masks_add = []
                times_add = times_uniques
                area_add = np.array([])
                have_cons_add = np.array([])
            tuple_add = (Masksb_add, masks_add, times_add, area_add, have_cons_add)

        # temporal merging 5: merge newly found neurons within the recent "merge_every" frames with existing neurons
        if ((t - 3 - t_merge) == merge_every) or (t==nframesf-1):
            tuple_temp = merge_2(tuple_temp, tuple_add, dims, Params_post)
            t_merge += merge_every
            if show_intermediate:
                Masks_2 = select_cons(tuple_temp)
                if len(Masks_2):
                    Masks_2 = sparse.vstack(Masks_2)
                else:
                    Masks_2 = sparse.csc_matrix((0,dims[0]*dims[1]))
                list_Masks_2.append(Masks_2)

        current_frame +=1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf+1
            if past_frames is not None:
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames have been processed'.format(t))

    if not show_intermediate:
        Masks_2 = select_cons(tuple_temp)
        # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
        if len(Masks_2):
            Masks_2 = sparse.vstack(Masks_2)
        else:
            Masks_2 = sparse.csc_matrix((0,dims[0]*dims[1]))
    # The final result is always the last element of the list
    list_Masks_2.append(Masks_2)

    if display:
        end_online = time.time()
        time_online = end_online-start_online
        # Avoid dividing by zero when the video has no frame left after initialization
        time_frame_online = time_online/max(nframesf-frames_initf, 1)*1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final-start_init
        time_frame_all = time_all/nframes*1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array([time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3,))
        time_frame = np.zeros((3,))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(), (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame, list_time_per, list_Masks_2