# Example #1
# 0
def optimize_combine_1(uniques: sparse.csr_matrix, times_uniques: list,
                       dims: tuple, Params: dict, filename_GT: str):
    '''Optimize 1 post-processing parameter: "cons". 
        Start after the first COM merging.
        The outputs are the recall, precision, and F1 calculated using all values in "list_cons".

    Inputs: 
        uniques (sparse.csr_matrix of float32, shape = (n,Lx*Ly)): the neuron masks to be merged.
        times_uniques (list of 1D numpy.array): indices of frames when the neuron is active.
        dims (tuple of int): the shape of the image. Either the lateral shape (Lx, Ly),
            or a shape with at least 3 entries whose entries 1 and 2 are (Lx, Ly)
            (e.g. (nframes, Lx, Ly)).
        Params (dict): Ranges of post-processing parameters to optimize over.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: (float) Threshold to binarize the real-number mask.
            Params['thresh_COM']: (float or int) Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params['thresh_IOU']: (float) Threshold of IoU used for merging neurons.
                The consume-ratio threshold is derived from it as (1 + thresh_IOU) / 2.
            Params['list_cons']: (list) Range of minimum number of consecutive frames that a neuron should be active for.
        filename_GT (str): file name of the GT masks. 
            The GT masks are stored in a ".mat" file, and dataset "GTMasks_2" is the GT masks
            (shape = (Ly0*Lx0,n) when saved in MATLAB).

    Outputs:
        Recall_k (1D numpy.array of float): Recall for all cons. 
        Precision_k (1D numpy.array of float): Precision for all cons. 
        F1_k (1D numpy.array of float): F1 for all cons. 
    '''
    avgArea = Params['avgArea']
    thresh_mask = Params['thresh_mask']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    # The consume-ratio threshold is not searched on its own; it is tied to thresh_IOU.
    thresh_consume = (1 + thresh_IOU) / 2
    list_cons = Params['list_cons']

    # Accept both the documented lateral shape (Lx, Ly) and a longer shape such as
    # (nframes, Lx, Ly). Indexing dims[1], dims[2] unconditionally fails on a 2-tuple.
    if len(dims) > 2:
        dims_lateral = (dims[1], dims[2])
    else:
        dims_lateral = (dims[0], dims[1])

    # second merge neurons with close COM.
    groupedneurons, times_groupedneurons = group_neurons(uniques, \
        thresh_COM, thresh_mask, dims_lateral, times_uniques, useMP=False)
    # Merge neurons with high IoU.
    piecedneurons_1, times_piecedneurons_1 = piece_neurons_IOU(groupedneurons, \
        thresh_mask, thresh_IOU, times_groupedneurons)
    # Merge neurons with high consume ratio.
    piecedneurons, times_piecedneurons = piece_neurons_consume(piecedneurons_1, \
        avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
    masks_final_2 = piecedneurons
    times_final = [np.unique(x) for x in times_piecedneurons]

    data_GT = loadmat(filename_GT)
    # The MATLAB array is stored as (pixels, n); transpose to one mask per row.
    GTMasks_2 = data_GT['GTMasks_2'].transpose()
    # Search for optimal "cons" used to refine segmented neurons.
    Recall_k, Precision_k, F1_k = refine_seperate_multi(GTMasks_2, \
        masks_final_2, times_final, list_cons, thresh_mask, display=False)
    return Recall_k, Precision_k, F1_k
# Example #2
# 0
def final_merge(tuple_temp, Params):
    '''Run one extra round of merging over every neuron detected during online processing.

        At the end of online processing, all previously detected neurons are merged again
        by IoU and then by consume ratio (as in batch mode), and finally screened with the
        consecutive-frame requirement.

    Inputs: 
        tuple_temp (tuple, shape = (5,)): Segmented masks with statistics. Only its second
            entry (the real-number masks, stacked with sparse.vstack) and its third entry
            (indices of active frames per mask) are used.
        Params (dict): Parameters for post-processing.
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.
            Params['avgArea']: The typical neuron area (unit: pixels).

    Outputs:
        If tuple_temp contains no masks, tuple_temp itself is returned unchanged
        (Params is not read in that case). Otherwise a 2-tuple is returned:
        Masks_2b (sparse.csr_matrix of bool): the final segmented binary neuron masks after consecutive refinement. 
        times_final (list of 1D numpy.array): indices of frames when the final neuron is active.
        NOTE(review): the two cases return tuples of different lengths (5 vs 2); a caller that
            unpacks two values fails when no mask was found. Confirm this is intended.
    '''
    _, real_masks, active_times, _, _ = tuple_temp
    if not len(real_masks):
        # Nothing to merge: hand the input back untouched.
        return tuple_temp

    thresh_mask, thresh_IOU, thresh_consume, cons, avgArea = (
        Params[key] for key in ('thresh_mask', 'thresh_IOU', 'thresh_consume', 'cons', 'avgArea'))
    stacked = sparse.vstack(real_masks)

    # Merge by IoU first, then by consume ratio.
    merged_iou, times_iou = piece_neurons_IOU(stacked, thresh_mask, thresh_IOU, active_times)
    merged, times_merged = piece_neurons_consume(
        merged_iou, avgArea, thresh_mask, thresh_consume, times_iou)
    # Keep only the neurons that satisfy the consecutive-frame requirement.
    binary_masks, kept_times = refine_seperate(merged, times_merged, cons, thresh_mask)
    return binary_masks, kept_times
# Example #3
# 0
def suns_online(filename_video, filename_CNN, Params_pre, Params_post, dims, \
        frames_init, merge_every, batch_size_init=1, useSF=True, useTF=True, useSNR=True, \
        med_subtract=False, update_baseline=False, \
        useWT=False, show_intermediate=True, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS online procedure.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video stored in "filename_video".

    Inputs: 
        filename_video (str): The path of the file of the input raw video.
            The file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
                NOTE(review): this function reads "nn" but still processes every frame of "mov";
                here it only sets the minimum size of the per-frame timing buffer.
        Params_post (dict): Parameters for post-processing.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it is converted to float before use.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        frames_init (int): Number of frames used for initialization.
        merge_every (int): SUNS online merge the newly segmented frames every "merge_every" frames.
        batch_size_init (int, default to 1): batch size of CNN inference for initialization frames.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
            When False, no buffer of past frames is kept.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        useWT (bool, default to False): Indicator of whether watershed is used. 
        show_intermediate (bool, default to True): Indicator of whether 
            consecutive frame requirement is applied to screen neurons after every update. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
        p (multiprocessing.Pool, default to None): pool forwarded to "init_online"
            (presumably used to parallelize initialization; confirm).

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. 
        Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
        time_total (numpy.ndarray of float, shape = (3,)): the total time spent 
            for initialization, online processing, and total processing (all zeros if display is False)
        time_frame (numpy.ndarray of float, shape = (3,)): the average time spent on every frame
            for initialization, online processing, and total processing (all zeros if display is False)
    '''
    if display:
        start = time.time()
    (Lx, Ly) = dims
    # zero-pad the lateral dimensions to multiples of 8, suitable for CNN
    rowspad = math.ceil(Lx / 8) * 8
    colspad = math.ceil(Ly / 8) * 8
    dimspad = (rowspad, colspad)

    Poisson_filt = Params_pre['Poisson_filt']
    gauss_filt_size = Params_pre['gauss_filt_size']
    nn = Params_pre['nn']
    leng_tf = Poisson_filt.size
    leng_past = 2 * leng_tf  # number of past frames stored for temporal filtering

    # Load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once to warm up
    init_imgs = np.zeros((batch_size_init, rowspad, colspad, 1),
                         dtype='float32')
    init_masks = np.zeros((batch_size_init, rowspad, colspad, 1),
                          dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_init)
    del init_imgs, init_masks

    # load optimal post-processing parameters
    minArea = Params_post['minArea']
    avgArea = Params_post['avgArea']
    thresh_mask = Params_post['thresh_mask']
    thresh_COM0 = Params_post['thresh_COM0']
    thresh_COM = Params_post['thresh_COM']
    thresh_IOU = Params_post['thresh_IOU']
    thresh_consume = Params_post['thresh_consume']
    cons = Params_post['cons']
    # thresh_pmap may be a uint8 value; convert to float first so "+ 1" cannot
    # wrap around (255 + 1 would overflow to 0 in uint8). "+ 1" matches the published version.
    thresh_pmap_float = (float(Params_post['thresh_pmap']) + 1) / 256

    # Spatial filtering preparation
    if useSF:
        # lateral dimensions slightly larger than the raw video but faster for FFT
        rows1 = cv2.getOptimalDFTSize(rowspad)
        cols1 = cv2.getOptimalDFTSize(colspad)

        # if the learned 2D and 3D wisdom files have been saved, load them.
        # Otherwise, learn wisdom later
        Length_data2 = str((rows1, cols1))
        cc2 = load_wisdom_txt('wisdom\\' + Length_data2)

        Length_data3 = str((frames_init, rows1, cols1))
        cc3 = load_wisdom_txt('wisdom\\' + Length_data3)
        if cc3:
            pyfftw.import_wisdom(cc3)

        # mask for spatial filter
        mask2 = plan_mask2(dims, (rows1, cols1), gauss_filt_size)
        # FFT planning
        (bb, bf, fft_object_b,
         fft_object_c) = plan_fft(frames_init, (rows1, cols1), prealloc)
    else:
        (mask2, bf, fft_object_b, fft_object_c) = (None, None, None, None)
        bb = np.zeros((frames_init, rowspad, colspad), dtype='float32')

    # With prealloc, np.ones writes every element, so the memory is (presumably)
    # committed up front instead of on first use.
    alloc = np.ones if prealloc else np.zeros

    # Temporal filtering preparation
    frames_initf = frames_init - leng_tf + 1
    if useTF:
        # past frames stored for temporal filtering
        past_frames = alloc((leng_past, rowspad, colspad), dtype='float32')
    else:
        past_frames = None

    # Memory for some future variables
    med_frame2 = alloc((rowspad, colspad, 2), dtype='float32')
    video_input = alloc((frames_initf, rowspad, colspad), dtype='float32')
    pmaps_b_init = alloc((frames_initf, Lx, Ly), dtype='uint8')
    frame_SNR = alloc(dimspad, dtype='float32')
    pmaps_b = alloc(dims, dtype='uint8')
    if update_baseline:
        video_tf_past = alloc((frames_init, rowspad, colspad), dtype='float32')

    if display:
        time_param = time.time()
        print('Parameter initialization time: {} s'.format(time_param - start))

    # %% Load raw video
    with h5py.File(filename_video, 'r') as h5_img:
        video_raw = np.array(h5_img['mov'])
    nframes = video_raw.shape[0]
    nframesf = nframes - leng_tf + 1
    # Per-frame processing time; large enough for every online frame index "t".
    list_time_per = np.zeros(max(nn, nframesf))
    bb[:, :Lx, :Ly] = video_raw[:frames_init]
    if display:
        time_load = time.time()
        print('Load data: {} s'.format(time_load - time_param))

    # %% Actual processing starts after the video is loaded into memory
    # Initialization using the first "frames_init" frames
    print('Initialization of algorithms using the first {} frames'.format(
        frames_init))
    if display:
        start_init = time.time()
    med_frame3, segs_all, recent_frames = init_online(
        bb, dims, video_input, pmaps_b_init, fff, thresh_pmap_float, Params_post, \
        med_frame2, mask2, bf, fft_object_b, fft_object_c, Poisson_filt, \
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, \
        useWT=useWT, batch_size_init=batch_size_init, p=p)
    if past_frames is not None:
        past_frames[:leng_tf] = recent_frames
    tuple_temp = merge_complete(segs_all[:frames_initf], dims, Params_post)
    if show_intermediate:
        Masks_2 = select_cons(tuple_temp)

    if display:
        end_init = time.time()
        time_init = end_init - start_init
        time_frame_init = time_init / (frames_initf) * 1000
        print('Initialization time: {:6f} s, {:6f} ms/frame'.format(
            time_init, time_frame_init))

    if display:
        start_online = time.time()
    # Spatial filtering preparation for online processing.
    # Attention: this part counts to the total time
    if useSF:
        if cc2:
            pyfftw.import_wisdom(cc2)
        (bb, bf, fft_object_b, fft_object_c) = plan_fft2((rows1, cols1))
    else:
        (bf, fft_object_b, fft_object_c) = (None, None, None)
        bb = np.zeros(dimspad, dtype='float32')

    print('Start frame by frame processing')
    # %% Online processing for the following frames
    current_frame = leng_tf + 1
    t_merge = frames_initf
    for t in range(frames_initf, nframesf):
        if display:
            start_frame = time.time()
        # load the current frame
        bb[:Lx, :Ly] = video_raw[t + leng_tf - 1]
        bb[Lx:, :] = 0
        bb[:, Ly:] = 0

        # PreProcessing. Without temporal filtering there is no buffer of past frames.
        if past_frames is not None:
            recent_past = past_frames[current_frame - leng_tf:current_frame]
        else:
            recent_past = None
        frame_SNR, frame_tf = preprocess_online(bb, dimspad, med_frame3, frame_SNR, \
            recent_past, mask2, bf, fft_object_b, \
            fft_object_c, Poisson_filt, useSF=useSF, useTF=useTF, useSNR=useSNR, \
            med_subtract=med_subtract, update_baseline=update_baseline)

        if update_baseline:
            t_past = (t - frames_initf) % frames_init
            video_tf_past[t_past] = frame_tf
            if t_past == frames_init - 1:
                # update median and median-based standard deviation every "frames_init" frames
                if useSNR:
                    med_frame3 = SNR_normalization(video_tf_past,
                                                   med_frame2,
                                                   (rowspad, colspad),
                                                   1,
                                                   display=False)
                else:
                    med_frame3 = median_normalization(video_tf_past,
                                                      med_frame2,
                                                      (rowspad, colspad),
                                                      1,
                                                      display=False)

        # CNN inference
        frame_prob = CNN_online(frame_SNR, fff, dims)

        # first step of post-processing
        segs = separate_neuron_online(frame_prob, pmaps_b, thresh_pmap_float,
                                      minArea, avgArea, useWT)
        segs_all.append(segs)

        # The merging steps below are staggered over consecutive frames so that each
        # frame only pays for one step; all steps run together on the last frame.

        # temporal merging 1: combine neurons with COM distance smaller than thresh_COM0
        if ((t + 1 - t_merge) == merge_every) or (t == nframesf - 1):
            totalmasks, neuronstate, COMs, areas, probmapID = segs_results(
                segs_all[t_merge:])
            uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs, \
                areas, probmapID, minArea=0, thresh_COM0=thresh_COM0, useMP=useMP)

        # temporal merging 2: combine neurons with COM distance smaller than thresh_COM
        if ((t - 0 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                groupedneurons, times_groupedneurons = \
                    group_neurons(uniques, thresh_COM, thresh_mask, dims, times_uniques, useMP=useMP)

        # temporal merging 3: combine neurons with IoU larger than thresh_IOU
        if ((t - 1 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                piecedneurons_1, times_piecedneurons_1 = \
                    piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)

        # temporal merging 4: combine neurons with consume ratio larger than thresh_consume
        if ((t - 2 - t_merge) == merge_every) or (t == nframesf - 1):
            if uniques.size:
                piecedneurons, times_piecedneurons = \
                    piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
                # masks of new neurons
                masks_add = piecedneurons
                # indices of frames when the neurons are active (shifted to absolute frame indices)
                times_add = [
                    np.unique(x) + t_merge for x in times_piecedneurons
                ]

                # Refine neurons using consecutive occurrence
                if masks_add.size:
                    # new real-number masks
                    masks_add = [x for x in masks_add]
                    # new binary masks
                    Masksb_add = [(x >= x.max() * thresh_mask).astype('float')
                                  for x in masks_add]
                    # areas of new masks
                    area_add = np.array([x.nnz for x in Masksb_add])
                    # indicators of whether the new masks satisfy consecutive frame requirement
                    have_cons_add = refine_seperate_cons_online(
                        times_add, cons)
                else:
                    Masksb_add = []
                    area_add = np.array([])
                    have_cons_add = np.array([])

            else:  # does not find any active neuron
                Masksb_add = []
                masks_add = []
                times_add = times_uniques
                area_add = np.array([])
                have_cons_add = np.array([])
            tuple_add = (Masksb_add, masks_add, times_add, area_add,
                         have_cons_add)

        # temporal merging 5: merge newly found neurons within the recent "merge_every" frames with existing neurons
        if ((t - 3 - t_merge) == merge_every) or (t == nframesf - 1):
            tuple_temp = merge_2(tuple_temp, tuple_add, dims, Params_post)
            t_merge += merge_every
            if show_intermediate:
                Masks_2 = select_cons(tuple_temp)

        current_frame += 1
        # Update the stored latest frames when it runs out: move them "leng_tf" ahead
        if current_frame > leng_past:
            current_frame = leng_tf + 1
            if past_frames is not None:
                past_frames[:leng_tf] = past_frames[-leng_tf:]
        if display:
            end_frame = time.time()
            list_time_per[t] = end_frame - start_frame
        if t % 1000 == 0:
            print('{} frames has been processed'.format(t))

    if not show_intermediate:
        Masks_2 = select_cons(tuple_temp)
    # final result. Masks_2 is a 2D sparse matrix of the segmented neurons
    if len(Masks_2):
        Masks_2 = sparse.vstack(Masks_2)
    else:
        # same csr format as the vstack of csr rows above
        Masks_2 = sparse.csr_matrix((0, dims[0] * dims[1]))

    if display:
        end_online = time.time()
        time_online = end_online - start_online
        # guard against a video with no frames beyond the initialization window
        time_frame_online = time_online / max(nframesf - frames_initf, 1) * 1000
        print('Online time: {:6f} s, {:6f} ms/frame'.format(
            time_online, time_frame_online))

    # Save total processing time, and average processing time per frame
    if display:
        end_final = time.time()
        time_all = end_final - start_init
        time_frame_all = time_all / nframes * 1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(
            time_all, time_frame_all))
        time_total = np.array([time_init, time_online, time_all])
        time_frame = np.array(
            [time_frame_init, time_frame_online, time_frame_all])
    else:
        time_total = np.zeros((3, ))
        time_frame = np.zeros((3, ))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(),
                       (Masks_2.shape[0], Lx, Ly)).astype('bool')
    return Masks, Masks_2, time_total, time_frame
# Example #4
# 0
def merge_complete(segs, dims, Params):
    '''Temporally merge the masks segmented in a group of frames.

        Runs the full merging chain (first COM merge, second COM merge, IoU merge,
        consume-ratio merge). For every resulting neuron it then derives the binary
        mask, its area, and whether it satisfies the consecutive-frame requirement.

    Inputs: 
        segs (list): A list of segmented masks for every frame with statistics,
            in the form accepted by "segs_results".
        dims (tuple of int, shape = (2,)): lateral dimension of the image.
        Params (dict): Parameters for post-processing.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.

    Outputs:
        Masks_2 (list of sparse.csr_matrix of float, shape = (1,Lx*Ly)): 
            binary version of each segmented mask (values 0/1).
        masks_final_2 (list of sparse.csr_matrix of float32, shape = (1,Lx*Ly)): 
            2D representation of each segmented real-number mask.
        times_final (list of 1D numpy.ndarray of int): 
            indices of frames when each neuron is active.
        area (1D numpy.ndarray of int): number of nonzero pixels of each binary mask.
        have_cons (1D numpy.ndarray of bool): 
            indicators of whether each neuron satisfies the consecutive frame requirement.
        The above outputs are often grouped into a tuple (shape = (5,)): 
            Segmented masks with statistics after update.
    '''
    (avgArea, thresh_mask, thresh_COM0, thresh_COM,
     thresh_IOU, thresh_consume, cons) = (
        Params[key] for key in ('avgArea', 'thresh_mask', 'thresh_COM0', 'thresh_COM',
                                'thresh_IOU', 'thresh_consume', 'cons'))

    # First pass: merge per-frame regions whose COMs are within thresh_COM0.
    totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs)
    uniques, times_uniques = unique_neurons2_simp(
        totalmasks, neuronstate, COMs, areas, probmapID,
        minArea=0, thresh_COM0=thresh_COM0)
    if not uniques.size:
        # No neuron survived the first merge; keep the incoming frame indices.
        return [], [], times_uniques, np.array([]), np.array([])

    # Second COM merge, then IoU merge, then consume-ratio merge.
    grouped, times_grouped = group_neurons(
        uniques, thresh_COM, thresh_mask, dims, times_uniques)
    pieced_iou, times_iou = piece_neurons_IOU(
        grouped, thresh_mask, thresh_IOU, times_grouped)
    pieced, times_pieced = piece_neurons_consume(
        pieced_iou, avgArea, thresh_mask, thresh_consume, times_iou)
    times_final = [np.unique(frames) for frames in times_pieced]

    if not pieced.size:
        # Merging left no mask; return the (empty) merged result as-is.
        return [], pieced, times_final, np.array([]), np.array([])

    # Split into per-neuron rows and binarize each relative to its own maximum.
    real_rows = list(pieced)
    binary_rows = [(row >= row.max() * thresh_mask).astype('float') for row in real_rows]
    area = np.array([mask.nnz for mask in binary_rows])
    have_cons = refine_seperate_cons_online(times_final, cons)
    return binary_rows, real_rows, times_final, area, have_cons
# Example #5
# 0
def merge_complete_nocons(uniques, times_uniques, dims, Params):
    '''Temporally merge masks after the first COM merging, ignoring the consecutive-frame requirement.

        Used for parameter optimization: "cons" is searched in a later step, so every
        resulting neuron is marked as satisfying the consecutive-frame requirement.
        The merging chain is second COM merge, IoU merge, then consume-ratio merge.

    Inputs: 
        uniques (sparse.csr_matrix): the neuron masks after the first COM merging. 
        times_uniques (list of 1D numpy.array): indices of frames when the neuron is active.
        dims (tuple of int, shape = (2,)): lateral dimension of the image.
        Params (dict): Parameters for post-processing.
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.

    Outputs:
        Masks_2 (list of sparse.csr_matrix of float, shape = (1,Lx*Ly)): 
            binary version of each segmented mask (values 0/1).
        masks_final_2 (list of sparse.csr_matrix of float32, shape = (1,Lx*Ly)): 
            2D representation of each segmented real-number mask.
        times_final (list of 1D numpy.ndarray of int): 
            indices of frames when each neuron is active
            (times_uniques itself when "uniques" has no nonzero entry).
        area (1D numpy.ndarray of int): number of nonzero pixels of each binary mask.
        have_cons (1D numpy.ndarray of bool): 
            all True, one per neuron (the consecutive-frame check is deferred).
        The above outputs are often grouped into a tuple (shape = (5,)): 
            Segmented masks with statistics after update.
    '''
    avgArea, thresh_mask, thresh_COM, thresh_IOU, thresh_consume = (
        Params[key] for key in ('avgArea', 'thresh_mask', 'thresh_COM',
                                'thresh_IOU', 'thresh_consume'))

    if not uniques.size:
        # Nothing to merge; keep the incoming frame indices.
        return [], [], times_uniques, np.array([]), np.array([])

    grouped, times_grouped = group_neurons(
        uniques, thresh_COM, thresh_mask, dims, times_uniques)
    pieced_iou, times_iou = piece_neurons_IOU(
        grouped, thresh_mask, thresh_IOU, times_grouped)
    pieced, times_pieced = piece_neurons_consume(
        pieced_iou, avgArea, thresh_mask, thresh_consume, times_iou)
    times_final = [np.unique(frames) for frames in times_pieced]

    if not pieced.size:
        # Merging left no mask; return the (empty) merged result as-is.
        return [], pieced, times_final, np.array([]), np.array([])

    real_rows = list(pieced)
    binary_rows = [(row >= row.max() * thresh_mask).astype('float') for row in real_rows]
    area = np.array([mask.nnz for mask in binary_rows])
    # "cons" is searched later, so treat every mask as a valid neuron here.
    have_cons = np.ones(len(real_rows), dtype='bool')
    return binary_rows, real_rows, times_final, area, have_cons
# Example #6
# 0
def _print_step_time(step_name, elapsed, nframes, n_neurons):
    '''Print one timing/progress line for a post-processing step of complete_segment.

    Inputs:
        step_name (str): label of the processing step.
        elapsed (float): wall-clock time spent in the step (unit: s).
        nframes (int): number of frames processed; used to report the per-frame time.
        n_neurons (int): number of neurons remaining after the step.
    '''
    # An empty frame stack must not turn display=True into a ZeroDivisionError.
    ms_per_frame = elapsed / nframes * 1000 if nframes else 0.0
    print('{:25s}: Used {:9.6f} s, {:9.6f} ms/frame, '.format(step_name, elapsed, ms_per_frame),
          '{:6d} segmented neurons.'.format(n_neurons))


def complete_segment(pmaps: np.ndarray,
                     Params: dict,
                     useMP=True,
                     useWT=False,
                     display=False,
                     p=None):
    '''Complete post-processing procedure.
        This can be run after or before probability thresholding, depending on whether Params['thresh_pmap'] is None.
        It first thresholds the "pmaps" (if Params['thresh_pmap'] is not None) into a binary array,
        then separates the active pixels into connected regions, discards regions smaller than Params['minArea'],
        uses optional watershed (if useWT=True) to further segment regions larger than Params['avgArea'],
        merges the regions from different frames with close COM, large IoU, or large consume ratio,
        and finally selects masks that are active for at least Params['cons'] frames.
        The outputs are "Masks_2", a 2D sparse matrix of the final segmented neurons,
        and "times_cons", a list of indices of frames when each final neuron is active.

    Inputs:
        pmaps (3D numpy.ndarray of uint8, shape = (nframes,Lx,Ly)): the probability map obtained after CNN inference.
            If Params['thresh_pmap']==None, pmaps must be previously thresholded.
        Params (dict): Parameters for post-processing.
            Params['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params['avgArea']: The typical neuron area (unit: pixels).
            Params['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels.
                if Params['thresh_pmap']==None, then thresholding is not performed.
                This is used when thresholding is done before this function.
            Params['thresh_mask']: Threshold to binarize the real-number mask.
            Params['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging.
            Params['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging.
            Params['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params['cons']: Minimum number of consecutive frames that a neuron should be active for.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up
            the per-frame segmentation. It only takes effect when a pool "p" is provided;
            otherwise the frames are processed serially.
        useWT (bool, default to False): Indicator of whether watershed is used.
        display (bool, default to False): Indicator of whether to show intermediate information.
        p (multiprocessing.Pool, default to None): pool whose "starmap" is used to segment
            frames in parallel when useMP is True.

    Outputs:
        Masks_2 (sparse.csr_matrix of bool): the final segmented binary neuron masks after consecutive refinement.
            When no region is found in any frame, an empty matrix of shape (0, Lx*Ly) is returned.
        times_cons (list of 1D numpy.array): indices of frames when the final neuron is active.
    '''
    (nframes, Lx, Ly) = pmaps.shape
    minArea = Params['minArea']
    avgArea = Params['avgArea']
    thresh_pmap = Params['thresh_pmap']
    thresh_mask = Params['thresh_mask']
    thresh_COM0 = Params['thresh_COM0']
    thresh_COM = Params['thresh_COM']
    thresh_IOU = Params['thresh_IOU']
    thresh_consume = Params['thresh_consume']
    cons = Params['cons']
    start_all = time.time()

    # Segment neuron masks from each frame of the probability map.
    # Parallel segmentation needs a pool; without one, use the serial path
    # instead of failing on None.starmap.
    start = time.time()
    if useMP and p is not None:
        segs = p.starmap(separate_neuron,
                         [(frame, thresh_pmap, minArea, avgArea, useWT)
                          for frame in pmaps],
                         chunksize=1)
    else:
        segs = [
            separate_neuron(frame, thresh_pmap, minArea, avgArea, useWT)
            for frame in pmaps
        ]
    end = time.time()
    # The per-frame neuron count is taken from the size of the second item of each result.
    num_neurons = sum(x[1].size for x in segs)
    if display:
        _print_step_time('separate Neurons', end - start, nframes, num_neurons)

    if num_neurons == 0:
        print('No masks found. Please lower minArea or thresh_pmap.')
        # Empty result in the documented csr format.
        Masks_2 = sparse.csr_matrix((0, Lx * Ly), dtype='bool')
        times_cons = []
        return Masks_2, times_cons

    # Initially merge neurons with close COM.
    start = time.time()
    totalmasks, neuronstate, COMs, areas, probmapID = segs_results(segs)
    uniques, times_uniques = unique_neurons2_simp(totalmasks, neuronstate, COMs,
                                                  areas, probmapID, minArea=0,
                                                  thresh_COM0=thresh_COM0)
    end_unique = time.time()
    if display:
        _print_step_time('unique_neurons1', end_unique - start, nframes, len(times_uniques))

    # Further merge neurons with close COM.
    groupedneurons, times_groupedneurons = \
        group_neurons(uniques, thresh_COM, thresh_mask, (Lx, Ly), times_uniques)
    end_COM = time.time()
    if display:
        _print_step_time('group_neurons', end_COM - end_unique, nframes, len(times_groupedneurons))

    # Merge neurons with high IoU.
    piecedneurons_1, times_piecedneurons_1 = \
        piece_neurons_IOU(groupedneurons, thresh_mask, thresh_IOU, times_groupedneurons)
    end_IOU = time.time()
    if display:
        _print_step_time('piece_neurons_IOU', end_IOU - end_COM, nframes, len(times_piecedneurons_1))

    # Merge neurons with high consume ratio.
    piecedneurons, times_piecedneurons = \
        piece_neurons_consume(piecedneurons_1, avgArea, thresh_mask, thresh_consume, times_piecedneurons_1)
    end_consume = time.time()
    if display:
        _print_step_time('piece_neurons_consume', end_consume - end_IOU, nframes, len(times_piecedneurons))

    masks_final_2 = piecedneurons
    # Deduplicate the active-frame indices of each merged neuron.
    times_final = [np.unique(x) for x in times_piecedneurons]

    # Refine neurons using the consecutive occurrence requirement.
    start = time.time()
    Masks_2, times_cons = refine_seperate(masks_final_2, times_final, cons,
                                          thresh_mask)
    end_all = time.time()
    if display:
        # Report the neurons that survived refinement, i.e. what is actually returned.
        _print_step_time('refine_seperate', end_all - start, nframes, len(times_cons))
        _print_step_time('Total time', end_all - start_all, nframes, len(times_cons))

    return Masks_2, times_cons