Esempio n. 1
0
def suns_batch(dir_video, Exp_ID, filename_CNN, Params_pre, Params_post, dims,
        batch_size_eval=1, useSF=True, useTF=True, useSNR=True, med_subtract=False,
        useWT=False, prealloc=True, display=True, useMP=True, p=None):
    '''The complete SUNS batch procedure.
        It uses the trained CNN model from "filename_CNN" and the optimized hyper-parameters in "Params_post"
        to process the video "Exp_ID" in "dir_video".

    Inputs: 
        dir_video (str): The folder containing the input video.
            Each file must be a ".h5" file, with dataset "mov" being the input video (shape = (T0,Lx0,Ly0)).
        Exp_ID (str): The file name of the input raw video. 
        filename_CNN (str): The path of the trained CNN model. 
        Params_pre (dict): Parameters for pre-processing.
            Params_pre['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels
            Params_pre['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel
            Params_pre['num_median_approx'] (int): Number of frames used to compute 
                the median and median-based standard deviation
            Params_pre['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
        Params_post (dict): Parameters for post-processing. It is not modified; a shallow copy is used internally.
            Params_post['minArea']: Minimum area of a valid neuron mask (unit: pixels).
            Params_post['avgArea']: The typical neuron area (unit: pixels).
            Params_post['thresh_pmap']: The probability threshold. Values higher than thresh_pmap are active pixels. 
                It is stored in uint8, so it should be converted to float32 before using.
            Params_post['thresh_mask']: Threshold to binarize the real-number mask.
            Params_post['thresh_COM0']: Threshold of COM distance (unit: pixels) used for the first COM-based merging. 
            Params_post['thresh_COM']: Threshold of COM distance (unit: pixels) used for the second COM-based merging. 
            Params_post['thresh_IOU']: Threshold of IOU used for merging neurons.
            Params_post['thresh_consume']: Threshold of consume ratio used for merging neurons.
            Params_post['cons']: Minimum number of consecutive frames that a neuron should be active for.
        dims (tuple of int, shape = (2,)): lateral dimension of the raw video.
        batch_size_eval (int, default to 1): batch size of CNN inference.
        useSF (bool, default to True): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used. 
        useWT (bool, default to False): Indicator of whether watershed is used. 
        prealloc (bool, default to True): True if pre-allocate memory space for large variables. 
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): Indicator of whether to show intermediate information.
            The timing outputs are only measured when this is True.
        useMP (bool, default to True): indicator of whether multiprocessing is used to speed up. 
            Not used inside this function; kept so that existing callers keep working.
        p (multiprocessing.Pool, default to None): passed unchanged to "complete_segment".

    Outputs:
        Masks (3D numpy.ndarray of bool, shape = (n,Lx0,Ly0)): the final segmented masks. 
        Masks_2 (scipy.csr_matrix of bool, shape = (n,Lx0*Ly0)): the final segmented masks in the form of sparse matrix. 
        time_total (1D numpy.ndarray of float, shape = (4,)): the total time spent 
            for pre-processing, CNN inference, post-processing, and total processing.
            All zeros when "display" is False.
        time_frame (1D numpy.ndarray of float, shape = (4,)): the average time spent on every frame
            for pre-processing, CNN inference, post-processing, and total processing.
            All zeros when "display" is False.
    '''
    if display:
        start_init = time.time()
    (Lx, Ly) = dims
    # The CNN input is padded so that both lateral sizes are multiples of 16.
    rowspad = math.ceil(Lx / 16) * 16
    colspad = math.ceil(Ly / 16) * 16

    # load CNN model
    fff = get_shallow_unet()
    fff.load_weights(filename_CNN)
    # run CNN inference once on dummy data to warm up,
    # so that one-off start-up cost is not counted in the inference time below
    init_imgs = np.zeros((batch_size_eval, rowspad, colspad, 1),
                         dtype='float32')
    init_masks = np.zeros((batch_size_eval, rowspad, colspad, 1),
                          dtype='uint8')
    fff.evaluate(init_imgs, init_masks, batch_size=batch_size_eval)
    del init_imgs, init_masks

    # Convert the uint8 threshold to the float scale of the probability map.
    # The published version used (Params_post['thresh_pmap'] + 1) / 256 instead.
    thresh_pmap_float = (Params_post['thresh_pmap'] + 1.5) / 256
    if display:
        time_init = time.time()
        print('Initialization time: {} s'.format(time_init - start_init))

    # %% Actual processing starts after the video is loaded into memory
    # which is in the middle of "preprocess_video", represented by the output "start"
    # pre-processing including loading data
    video_input, start = preprocess_video(
        dir_video, Exp_ID, Params_pre,
        useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract,
        prealloc=prealloc, display=display)
    nframes = video_input.shape[0]
    if display:
        end_pre = time.time()
        time_pre = end_pre - start
        time_frame_pre = time_pre / nframes * 1000
        print('Pre-Processing time: {:6f} s, {:6f} ms/frame'.format(
            time_pre, time_frame_pre))

    # CNN inference (a trailing channel axis is added for the network input)
    video_input = np.expand_dims(video_input, axis=-1)
    prob_map = fff.predict(video_input, batch_size=batch_size_eval)
    if display:
        end_network = time.time()
        time_CNN = end_network - end_pre
        time_frame_CNN = time_CNN / nframes * 1000
        print('CNN Infrence time: {:6f} s, {:6f} ms/frame'.format(
            time_CNN, time_frame_CNN))

    # post-processing
    # Drop only the trailing channel axis. A bare squeeze() would also drop the
    # frame axis of a single-frame video and break the 3D slicing below.
    # Then crop the padding back to the original lateral size.
    prob_map = np.squeeze(prob_map, axis=-1)[:, :Lx, :Ly]
    if display:
        print(Params_post)
    Params_post_copy = Params_post.copy()
    # Avoid repeated thresholding in postprocessing
    Params_post_copy['thresh_pmap'] = None
    pmaps_b = np.zeros(prob_map.shape, dtype='uint8')
    # threshold the probability map to binary activity
    fastthreshold(prob_map, pmaps_b, thresh_pmap_float)

    # the rest of post-processing. The result is a 2D sparse matrix of the segmented neurons
    Masks_2 = complete_segment(pmaps_b,
                               Params_post_copy,
                               display=display,
                               p=p,
                               useWT=useWT)
    if display:
        finish = time.time()
        time_post = finish - end_network
        time_frame_post = time_post / nframes * 1000
        print('Post-Processing time: {:6f} s, {:6f} ms/frame'.format(
            time_post, time_frame_post))

    # convert to a 3D array of the segmented neurons
    Masks = np.reshape(Masks_2.toarray(),
                       (Masks_2.shape[0], Lx, Ly)).astype('bool')

    # Save total processing time, and average processing time per frame.
    # The conversion to a 3D array above is not included in the total time.
    if display:
        time_all = finish - start
        time_frame_all = time_all / nframes * 1000
        print('Total time: {:6f} s, {:6f} ms/frame'.format(
            time_all, time_frame_all))
        time_total = np.array([time_pre, time_CNN, time_post, time_all])
        time_frame = np.array(
            [time_frame_pre, time_frame_CNN, time_frame_post, time_frame_all])
    else:
        # Timing is not measured when display is off.
        time_total = np.zeros((4, ))
        time_frame = np.zeros((4, ))

    return Masks, Masks_2, time_total, time_frame
Esempio n. 2
0
    Params_set = {
        'list_minArea': list_minArea,
        'list_avgArea': list_avgArea,
        'list_thresh_pmap': list_thresh_pmap,
        'thresh_COM0': thresh_COM0,
        'list_thresh_COM': list_thresh_COM,
        'list_thresh_IOU': list_thresh_IOU,
        'thresh_mask': thresh_mask,
        'list_cons': list_cons
    }
    print(Params_set)

    # pre-processing for training
    for Exp_ID in list_Exp_ID:  #
        # %% Pre-process video
        video_input, _ = preprocess_video(dir_video, Exp_ID, Params_pre, dir_network_input, \
            useSF=useSF, useTF=useTF, useSNR=useSNR, med_subtract=med_subtract, prealloc=prealloc) #

        # %% Determine active neurons in all frames using FISSA
        file_mask = dir_GTMasks + Exp_ID + '.mat'  # foldr to save the temporal masks
        generate_masks(video_input, file_mask, list_thred_ratio, dir_parent,
                       Exp_ID)
        del video_input

    # %% CNN training
    if cross_validation == "use_all":
        list_CV = [nvideo]
    else:
        list_CV = list(range(0, nvideo))
    for CV in list_CV:
        if cross_validation == "leave_one_out":
            list_Exp_ID_train = list_Exp_ID.copy()
Esempio n. 3
0
            # dictionary of all fixed and searched post-processing parameters.
            Params_set = {
                'list_minArea': list_minArea,
                'list_avgArea': list_avgArea,
                'list_thresh_pmap': list_thresh_pmap,
                'thresh_COM0': thresh_COM0,
                'list_thresh_COM': list_thresh_COM,
                'list_thresh_IOU': list_thresh_IOU,
                'thresh_mask': thresh_mask,
                'list_cons': list_cons
            }
            print(Params_set)

            # pre-processing for training
            # %% Pre-process video
            video_input, _ = preprocess_video(dir_video, Exp_ID, Params, dir_network_input, \
                useSF=useSF, useTF=useTF, useSNR=useSNR, prealloc=prealloc) #

            # %% Determine active neurons in all frames using FISSA
            file_mask = dir_GTMasks + Exp_ID + '.mat'  # foldr to save the temporal masks
            generate_masks(video_input, file_mask, list_thred_ratio,
                           dir_parent, Exp_ID)
            del video_input

            list_Exp_ID_train = [Exp_ID]
            list_Exp_ID_val = None  # Afternatively, we can get rid of validation steps
            file_CNN = weights_path + 'Model_{}.h5'.format(Exp_ID)
            file_CNN_2 = weights_path + 'Model_CV0.h5'
            results = train_CNN(dir_network_input, dir_mask, file_CNN, list_Exp_ID_train, list_Exp_ID_val, \
                BATCH_SIZE, NO_OF_EPOCHS, num_train_per, num_total, (rows, cols), Params_loss)
            copyfile(file_CNN, file_CNN_2)