コード例 #1
0
def _stack_masks(list_masks, n_pixels):
    '''Stack a list of sparse row masks into one 2D sparse matrix.

    Inputs:
        list_masks (list of sparse matrices): the masks to stack vertically.
        n_pixels (int): number of columns used for the empty result.

    Outputs:
        A sparse matrix with one row per mask. "sparse.vstack" raises ValueError on an empty list,
        so an empty (0, n_pixels) matrix is returned instead when "list_masks" is empty.
    '''
    if len(list_masks):
        return sparse.vstack(list_masks)
    return sparse.csc_matrix((0, n_pixels))


def suns_online_track(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure with tracking.
        It loads the CNN weights from "filename_CNN", initializes the neuron masks from the first
        "frames_init" frames of the video in "filename_video", and then processes the remaining frames
        one by one. In every online frame, each newly segmented mask is first matched against the
        already found neurons that satisfy the consecutive frame requirement; the unmatched masks are
        merged in batches every "merge_every" frames.

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model (weights loaded into "get_shallow_unet()"). 
        Params_pre (dict): Parameters for pre-processing. Keys read by this function:
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Other keys (e.g. 'num_median_approx', 'nn') are not read by this function.
        Params_post (dict): Parameters for post-processing. It is also passed as a whole to
            "init_online", "merge_complete" and "merge_2_nocons".
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, and converted here to a float threshold as (thresh_pmap + 1) / 256.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension (Lx, Ly) of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to measure and print the processing time.
        useMP (bool, default to True): Kept for interface compatibility; not used in this function. 
        p (multiprocessing.Pool, default to None): Passed to "init_online".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx,Ly)): the final segmented masks. 
        Masks_2 (scipy sparse matrix, shape = (n,Lx*Ly)): the final segmented masks in the form of sparse matrix. 
        time_total (1D numpy.ndarray of float, shape = (3,)): the total time spent (unit: s)
            for initialization, online processing, and total processing. All zeros if "display" is False.
        time_frame (1D numpy.ndarray of float, shape = (3,)): the average time spent on every frame (unit: ms)
            for initialization, online processing, and total processing. All zeros if "display" is False.
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    npixels = Lx * Ly
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx / 8) * 8
    colspad = math.ceil(Ly / 8) * 8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    leng_tf = Poisson_filt.size
    leng_past = 2 * leng_tf  # number of past frames stored for temporal filtering

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1),
                         dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1),
                          dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    thresh_mask = Params_post['thresh_mask']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    # uint8 threshold converted to a float probability threshold (published version)
    thresh_pmap_float = (Params_post['thresh_pmap'] + 1) / 256

    # Spatial filtering preparation.
    # Plain truthiness is used here so that this test agrees with the "if useSF:" test
    # before the online loop (which relies on "cc2" and "rows1" defined in this branch).
    if useSF:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)

        # if the learned 2D and 3D wisdom files have been saved, load them.
        # Otherwise, learn wisdom later
        # NOTE(review): the Windows-style separator is hard-coded; it is kept unchanged because
        # it presumably has to match the name used when the wisdom files were saved.
        Length_data2 = str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\' + Length_data2)

        Length_data3 = str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\' + Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b,
         fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    # number of frames that remain valid after temporal filtering of the initialization frames
    frames_initf = frames_init - leng_tf + 1
    if useTF:
        if prealloc:
            # past frames stored for temporal filtering
            past_frames = np.ones((leng_past, rowspad, colspad),
                                  dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad),
                                   dtype='float32')
    else:
        past_frames = None

    if prealloc:  # Pre-allocate memory for some future variables
        med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        video_input = np.ones((frames_initf, rowspad, colspad),
                              dtype='float32')
        pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.ones(dimspad, dtype='float32')
        pmaps_b = np.ones(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.ones((frames_init, rowspad, colspad),
                                    dtype='float32')
    else:
        med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        video_input = np.zeros((frames_initf, rowspad, colspad),
                               dtype='float32')
        pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.zeros(dimspad, dtype='float32')
        pmaps_b = np.zeros(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.zeros((frames_init, rowspad, colspad),
                                     dtype='float32')

    if display:
        time_param = time.time()
        print('Parameter initialization time: {} s'.format(time_param - start))

    # %% Load raw video
    # The context manager closes the file even if reading the dataset fails.
    with h5py.File(filename_video, 'r') as h5_img:
        video_raw = np.array(h5_img['mov'])
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load - time_param))

    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(
        frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if useTF:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)

    # Initialize Online track variables
    (Masksb_temp, masks_temp, times_temp, area_temp,
     have_cons_temp) = tuple_temp
    # list of previously found neurons that satisfy consecutive frame requirement
    Masks_cons = select_cons(tuple_temp)
    # sparse matrix of previously found neurons that satisfy consecutive frame requirement
    # (an empty matrix if no such neuron was found during initialization)
    Masks_cons_2D = _stack_masks(Masks_cons, npixels)
    # indices of previously found neurons that satisfy consecutive frame requirement
    ind_cons = have_cons_temp.nonzero()[0]
    segs0 = segs_all[0]  # segs of initialization frames
    # segs if no neurons are found
    segs_empty = (segs0[0][0:0], segs0[1][0:0], segs0[2][0:0], segs0[3][0:0])
    # Number of previously found neurons that satisfy consecutive frame requirement
    N1 = len(Masks_cons)
    # list of "segs" for neurons that are not previously found
    list_segs_new = []
    # list of newly segmented masks for old neurons (segmented in previous frames)
    list_masks_old = [[] for _ in range(N1)]
    # list of the newly active indices of frames of old neurons
    times_active_old = [[] for _ in range(N1)]
    # True if the old neurons are active in the previous frame
    active_old_previous = np.zeros(N1, dtype='bool')

    if display:
        end_init = time.time()
        time_init = end_init - start_init
        time_frame_init = time_init / (frames_initf) * 1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(
            time_init, time_frame_init))

    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing.
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb = np.zeros(dimspad, dtype='float32')

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf + 1
    t_merge = frames_initf
    for t in range(frames_initf, nframesf):
        # load the current frame and clear the zero-padded margins
        bb[:Lx, :Ly] = video_raw[t + leng_tf - 1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing
        # "past_frames" is None when temporal filtering is disabled, so it can not be sliced.
        # NOTE(review): presumably "preprocess_online" ignores this argument when useTF is False.
        if past_frames is not None:
            past_window = past_frames[current_frame - leng_tf:current_frame]
        else:
            past_window = None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            past_window, mask2, bf, fft_object_b, fft_object_c, \
            Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t - frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if t_past == frames_init - 1:
                # update median and median-based standard deviation every "frames_init" frames
                if useSNR:
                    med_frame3 = SNR_normalization(video_tf_past,
                                                   med_frame2,
                                                   (rowspad, colspad),
                                                   1,
                                                   display=False)
                else:
                    med_frame3 = median_normalization(video_tf_past,
                                                      med_frame2,
                                                      (rowspad, colspad),
                                                      1,
                                                      display=False)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float,
                                      minArea, avgArea, useWT)

        # True if the old neurons are active in the current frame
        active_old = np.zeros(N1, dtype='bool')
        masks_t, neuronstate_t, cents_t, areas_t = segs
        N2 = neuronstate_t.size
        if N2:  # Try to merge the new masks to old neurons
            new_found = np.zeros(N2, dtype='bool')
            for n2 in range(N2):
                masks_t2 = masks_t[n2]
                # Flat pixel index of the COM of the new mask. It is converted to a Python int
                # so that it is a valid index regardless of the dtype of "cents_t"
                # (a float or a narrow integer NumPy scalar could fail or overflow).
                cents_t2 = int(np.round(cents_t[n2, 1])) * Ly \
                    + int(np.round(cents_t[n2, 0]))
                # If a new masks belongs to an old neuron, the COM of the new mask must be inside the old neuron area.
                # Select possible old neurons that the new mask can merge to
                if N1:
                    possible_masks1 = Masks_cons_2D[:, cents_t2].nonzero()[0]
                else:  # no old neuron exists yet
                    possible_masks1 = np.zeros(0, dtype='int')
                IOUs = np.zeros(len(possible_masks1))
                areas_t2 = areas_t[n2]
                for (ind, n1) in enumerate(possible_masks1):
                    # Calculate IoU and consume ratio to determine merged neurons
                    area_i = Masks_cons[n1].multiply(masks_t2).nnz
                    area_temp1 = area_temp[n1]
                    area_u = area_temp1 + areas_t2 - area_i
                    IOU = area_i / area_u
                    consume = area_i / min(area_temp1, areas_t2)
                    contain = (IOU >= thresh_IOU) or (consume >=
                                                      thresh_consume)
                    if contain:  # merging criterion satisfied
                        IOUs[ind] = IOU
                num_contains = IOUs.nonzero()[0].size
                if num_contains:  # The new mask can merge to one of the old neurons.
                    # If there are multiple candidates, choose the one with the highest IoU
                    belongs = possible_masks1[IOUs.argmax()]
                    # merge the mask and active frame index
                    list_masks_old[belongs].append(masks_t2)
                    # NOTE(review): "t" already starts at "frames_initf"; confirm that the extra
                    # "frames_initf" offset of the stored frame index is intended.
                    times_active_old[belongs].append(t + frames_initf)
                    # This old neurons is active in the current frame
                    active_old[belongs] = True
                else:  # The new mask can not merge to any old neuron.
                    new_found[n2] = True

            if np.any(new_found):
                # There are some new masks that can not merge to old neurons
                segs_new = (masks_t[new_found], neuronstate_t[new_found],
                            cents_t[new_found], areas_t[new_found])
            else:  # All masks already merged to old neurons
                segs_new = segs_empty

        else:  # No neurons found in the current frame
            segs_new = segs
        list_segs_new.append(segs_new)

        if (t + 1 - t_merge) != merge_every or t == (nframesf - 1):
            # Update the old neurons with new appearances in the current frame.
            if t < (nframesf - 1):
                # True if the neurons are active in the previous frame but not active in the current frame
                inactive = np.logical_and(
                    active_old_previous,
                    np.logical_not(active_old)).nonzero()[0]
            else:  # last frame
                # All active neurons should be updated, so they are treated as inactive in the next frame
                # NOTE(review): only neurons active in the previous frame are selected here;
                # confirm whether neurons that became active in the last frame should be included.
                inactive = active_old_previous.nonzero()[0]

            # Update the indicators of the previous frame using the current frame
            active_old_previous = active_old.copy()
            for n1 in inactive:
                # merge new active frames to existing active frames for already found neurons
                # n1 is the index in the old neurons that satisfy consecutive frame requirement.
                # n10 is the index in all old neurons.
                n10 = ind_cons[n1]
                # Add all the new masks to the overall real-number masks
                mask_update = masks_temp[n10] + sum(list_masks_old[n1])
                masks_temp[n10] = mask_update
                # Add indices of active frames
                times_add = np.unique(np.array(times_active_old[n1]))
                times_temp[n10] = np.hstack([times_temp[n10], times_add])
                # reset lists used to store the information from new frames related to old neurons
                list_masks_old[n1] = []
                times_active_old[n1] = []
                # update the binary masks and areas
                Maskb_update = mask_update >= mask_update.max() * thresh_mask
                Masksb_temp[n10] = Maskb_update
                Masks_cons[n1] = Maskb_update
                area_temp[n10] = Maskb_update.nnz
            if inactive.size:
                Masks_cons_2D = _stack_masks(Masks_cons, npixels)

        if (t + 1 - t_merge) == merge_every or t == (nframesf - 1):
            if t < (nframesf - 1):
                # delay merging new frame to next frame by assuming all the neurons active in the previous frame
                # are still active in the current frame, to reserve merging time for new neurons
                active_old_previous = np.logical_or(active_old_previous,
                                                    active_old)

            # merge new neurons with old masks that do not satisfy consecutive frame requirement
            tuple_temp = (Masksb_temp, masks_temp, times_temp, area_temp,
                          have_cons_temp)
            # merge the remaining new masks from the most recent "merge_every" frames
            tuple_add = merge_complete(list_segs_new, dims, Params_post)
            (Masksb_add, masks_add, times_add, area_add,
             have_cons_add) = tuple_add
            # update the indices of active frames
            # NOTE(review): the offset is the constant "merge_every", not the start frame of
            # this batch; confirm that this is the intended frame index.
            times_add = [x + merge_every for x in times_add]
            tuple_add = (Masksb_add, masks_add, times_add, area_add,
                         have_cons_add)
            # merge the remaining new masks with the existing masks that do not satisfy consecutive frame requirement
            tuple_temp = merge_2_nocons(tuple_temp, tuple_add, dims,
                                        Params_post)

            (Masksb_temp, masks_temp, times_temp, area_temp,
             have_cons_temp) = tuple_temp
            # Update the indices of old neurons that satisfy consecutive frame requirement
            ind_cons_new = have_cons_temp.nonzero()[0]
            # Largest old index, computed once. "ind_cons" may be empty (no old neuron yet),
            # in which case every new index is appended.
            max_ind_old = ind_cons.max() if ind_cons.size else -1
            for (ind, ind_cons_0) in enumerate(ind_cons_new):
                if ind_cons_0 not in ind_cons:
                    # update lists used to store the information from new frames related to old neurons
                    if ind_cons_0 > max_ind_old:
                        list_masks_old.append([])
                        times_active_old.append([])
                    else:
                        list_masks_old.insert(ind, [])
                        times_active_old.insert(ind, [])

            # Update the list of previously found neurons that satisfy consecutive frame requirement
            Masks_cons = select_cons(tuple_temp)
            Masks_cons_2D = _stack_masks(Masks_cons, npixels)
            N1 = len(Masks_cons)
            list_segs_new = []
            # Update whether the old neurons are active in the previous frame
            # NOTE(review): this uses "active_old" only, so the logical-or computed at the
            # beginning of this branch is discarded; confirm that this is intended.
            active_old_previous = np.zeros_like(have_cons_temp)
            active_old_previous[ind_cons] = active_old
            active_old_previous = active_old_previous[ind_cons_new]
            ind_cons = ind_cons_new
            t_merge += merge_every

        current_frame += 1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf + 1
            if past_frames is not None:
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if t % 1000 == 0:
            print('{} frames has been processed'.format(t))

    Masks_cons = select_cons(tuple_temp)
    # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
    Masks_2 = _stack_masks(Masks_cons, npixels)

    if display:
        end_online = time.time()
        time_online = end_online - start_online
        # Avoid a division by zero when the video has no frame beyond initialization.
        n_online = max(nframesf - frames_initf, 1)
        time_frame_online = time_online / n_online * 1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(
            time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final - start_init
        time_frame_all = time_all / nframes * 1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(
            time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array(
            [time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3, ))
        time_frame = np.zeros((3, ))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(),
                       (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame
コード例 #2
0
def suns_online(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, show_intermediate=True, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video "Exp_ID" in "dir_video"

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probablity threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threashold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuplel of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        show_intermediate (bool, default to True): Indicator of whether 
            consecutive frame requirement is applied to screen neurons after every update. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
        useMP (bool, defaut to True): indicator of whether multiprocessing is used to speed up. 
        p (multiprocessing.Pool, default to None): 

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. 
        Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
        time_total (list of float, shape = (3,)): the total time spent 
            for initalization, online processing, and total processing
        time_frame (list of float, shape = (3,)): the average time spent on every frame
            for initalization, online processing, and total processing
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx / 8) * 8
    colspad = math.ceil(Ly / 8) * 8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2 * leng_tf  # number of past frames stored for temporal filtering
    list_time_per = np.zeros(nn)

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1),
                         dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1),
                          dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    # thresh_pmap = Params_post['thresh_pmap']
    thresh_mask = Params_post['thresh_mask']
    thresh_COM0 = Params_post['thresh_COM0']
    thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    cons = Params_post['cons']
    # thresh_pmap_float = (Params_post['thresh_pmap']+1.5)/256
    thresh_pmap_float = (Params_post['thresh_pmap'] +
                         1) / 256  # for published version

    # Spatial filtering preparation
    if useSF == True:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)

        # if the learned 2D and 3D wisdom files have been saved, load them.
        # Otherwise, learn wisdom later
        Length_data2 = str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\' + Length_data2)

        Length_data3 = str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\' + Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b,
         fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # Temporal filtering preparation
    frames_initf = frames_init - leng_tf + 1
    if useTF == True:
        if prealloc:
            # past frames stored for temporal filtering
            past_frames = np.ones((leng_past, rowspad, colspad),
                                  dtype='float32')
        else:
            past_frames = np.zeros((leng_past, rowspad, colspad),
                                   dtype='float32')
    else:
        past_frames = None

    if prealloc:  # Pre-allocate memory for some future variables
        med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        video_input = np.ones((frames_initf, rowspad, colspad),
                              dtype='float32')
        pmaps_b_init = np.ones((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.ones(dimspad, dtype='float32')
        pmaps_b = np.ones(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.ones((frames_init, rowspad, colspad),
                                    dtype='float32')
    else:
        med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        video_input = np.zeros((frames_initf, rowspad, colspad),
                               dtype='float32')
        pmaps_b_init = np.zeros((frames_initf, Lx, Ly), dtype='uint8')
        frame_SNR = np.zeros(dimspad, dtype='float32')
        pmaps_b = np.zeros(dims, dtype='uint8')
        if update_baseline:
            video_tf_past = np.zeros((frames_init, rowspad, colspad),
                                     dtype='float32')

    if display:
        time_init = time.time()
        print('Parameter initialization time: {} s'.format(time_init - start))

    # %% Load raw video
    h5_img = h5py.File(filename_video, 'r')
    video_raw = np.array(h5_img['mov'])
    h5_img.close()
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load - time_init))

    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(
        frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if useTF == True:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)
    if show_intermediate:
        Masks_2 = select_cons(tuple_temp)

    if display:
        end_init = time.time()
        time_init = end_init - start_init
        time_frame_init = time_init / (frames_initf) * 1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(
            time_init, time_frame_init))

    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing.
    # Attention: this part counts toward the total processing time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb = np.zeros(dimspad, dtype='float32')

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf + 1
    t_merge = frames_initf
    for t in range(frames_initf, nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t + leng_tf - 1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            past_frames[current_frame-leng_tf:current_frame], mask2, bf, fft_object_b, \
            fft_object_c, Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t - frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if t_past == frames_init - 1:
                # update median and median-based standard deviation every "frames_init" frames
                if useSNR:
                    med_frame3 = SNR_normalization(video_tf_past,
                                                   med_frame2,
                                                   (rowspad, colspad),
                                                   1,
                                                   display=False)
                else:
                    med_frame3 = median_normalization(video_tf_past,
                                                      med_frame2,
                                                      (rowspad, colspad),
                                                      1,
                                                      display=False)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float,
                                      minArea, avgArea, useWT)
        segs_all.append(segs)

        # temporal merging 1: combine neurons with COM distance smaller than thresh_COM0
        if ((t + 1 - t_merge) == merge_every) or (t == nframesf - 1):
            # uniques, times_uniques = unique_neurons1_simp(segs_all[t_merge:], thresh_COM0) # minArea,
            totalmasks, neuronstate, COMs, areas, probmapID = segs_results(
                segs_all[t_merge:])
            uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
                areas, probmapID, minArea=0, thresh_COM0=thresh_COM0, useMP=useMP)

        # temporal merging 2: combine neurons with COM distance smaller than thresh_COM
        if ((t - 0 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                groupedneurons, times_groupedneurons = \
                    group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques, useMP=useMP)

        # temporal merging 3: combine neurons with IoU larger than thresh_IOU
        if ((t - 1 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                piecedneurons_1, times_piecedneurons_1 = \
                    piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)

        # temporal merging 4: combine neurons with consume ratio larger than thresh_consume
        if ((t - 2 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                piecedneurons, times_piecedneurons = \
                    piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
                # masks of new neurons
                masks_add = piecedneurons
                # indices of frames when the neurons are active
                times_add = [
                    np.unique(x) + t_merge for x in times_piecedneurons
                ]

                # Refine neurons using consecutive occurrence
                if masks_add.size:
                    # new real-number masks
                    masks_add = [x for x in masks_add]
                    # new binary masks
                    Masksb_add = [(x >= x.max() * thresh_mask).astype('float')
                                  for x in masks_add]
                    # areas of new masks
                    area_add = np.array([x.nnz for x in Masksb_add])
                    # indicators of whether the new masks satisfy consecutive frame requirement
                    have_cons_add = refine_seperate_cons_online(
                        times_add, cons)
                else:
                    Masksb_add = []
                    area_add = np.array([])
                    have_cons_add = np.array([])

            else:  # does not find any active neuron
                Masksb_add = []
                masks_add = []
                times_add = times_uniques
                area_add = np.array([])
                have_cons_add = np.array([])
            tuple_add = (Masksb_add, masks_add, times_add, area_add,
                         have_cons_add)

        # temporal merging 5: merge newly found neurons within the recent "merge_every" frames with existing neurons
        if ((t - 3 - t_merge) == merge_every) or (t == nframesf - 1):
            tuple_temp = merge_2(tuple_temp, tuple_add, dims, Params_post)
            t_merge += merge_every
            if show_intermediate:
                Masks_2 = select_cons(tuple_temp)

        current_frame += 1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf + 1
            past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames has been processed'.format(t))

    if not show_intermediate:
        Masks_2 = select_cons(tuple_temp)
    # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
    if len(Masks_2):
        Masks_2 = sparse.vstack(Masks_2)
    else:
        Masks_2 = sparse.csc_matrix((0, dims[0] * dims[1]))

    if display:
        end_online = time.time()
        time_online = end_online - start_online
        time_frame_online = time_online / (nframesf - frames_initf) * 1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(
            time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final - start_init
        time_frame_all = time_all / nframes * 1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(
            time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array(
            [time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3, ))
        time_frame = np.zeros((3, ))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(),
                       (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame