Ejemplo n.º 1
0
    if not os.path.exists(weights_path):
        os.makedirs(weights_path)
    if not os.path.exists(training_output_path):
        os.makedirs(training_output_path)
    if not os.path.exists(dir_output):
        os.makedirs(dir_output)
    if not os.path.exists(dir_temp):
        os.makedirs(dir_temp)

    nvideo = len(list_Exp_ID)  # number of videos used for cross validation
    # Get and check the dimensions of all the videos
    list_Dimens = np.zeros((nvideo, 3), dtype='uint16')
    for (eid, Exp_ID) in enumerate(list_Exp_ID):
        h5_video = os.path.join(dir_video, Exp_ID + '.h5')
        h5_file = h5py.File(h5_video, 'r')
        dset = find_dataset(h5_file)
        list_Dimens[eid] = h5_file[dset].shape
        h5_file.close()

    nframes = np.unique(list_Dimens[:, 0])
    Lx = np.unique(list_Dimens[:, 1])
    Ly = np.unique(list_Dimens[:, 2])
    if len(Lx) * len(Ly) != 1:
        ValueError(
            '''The lateral dimensions of all the training videos must be the same in this version.
        Use another version if training videos have different dimensions''')
    nframes = nframes.min()
    rows = Lx[0]
    cols = Ly[0]

    rowspad = math.ceil(rows / 8) * 8  # size of the network input and output
Ejemplo n.º 2
0
def preprocess_video_online(dir_video:str, Exp_ID:str, Params:dict, frames_init=10**6,
        dir_network_input:str = None, useSF=False, useTF=True, useSNR=True,
        med_subtract=False, update_baseline=False, prealloc=True, display=True):
    '''Pre-process the registered video "Exp_ID" in "dir_video" into an SNR video "network_input".
        The process includes spatial filter, temporal filter, and SNR normalization. Each step is optional.
        The function does some preparations, including pre-allocating memory space (optional), FFT planning,
        and setting filter kernels, before loading the video.

    Inputs:
        dir_video (str): The folder containing the input video.
            Each file must be a ".h5" file.
            The video dataset can have any name, but cannot be under any group.
        Exp_ID (str): The file name of the input video (without the ".h5" extension).
        Params (dict): Parameters for pre-processing.
            Params['gauss_filt_size'] (float): The standard deviation of the spatial Gaussian filter in pixels.
                Only read when "useSF" is True.
            Params['Poisson_filt'] (1D numpy.ndarray of float32): The temporal filter kernel.
                Always read, and forwarded to "preprocess_complete_online" even when "useTF" is False.
            Params['nn'] (int): Number of frames at the beginning of the video to be processed.
                The remaining video is not considered a part of the input video.
            (Params['num_median_approx'] is not read by this function.)
        frames_init (int, default to a very large number): Median and median-based standard deviation are
            calculated from the first "frames_init" frames (before temporal filtering) of the video.
        dir_network_input (str, default to None): The folder to save the SNR video (network_input) in hard drive.
            If dir_network_input is None (or empty), then the SNR video is not stored in hard drive.
        useSF (bool, default to False): True if spatial filtering is used.
        useTF (bool, default to True): True if temporal filtering is used.
        useSNR (bool, default to True): True if pixel-by-pixel SNR normalization filtering is used.
        med_subtract (bool, default to False): True if the spatial median of every frame is subtracted before temporal filtering.
            Can only be used when spatial filtering is not used.
        update_baseline (bool, default to False): True if the median and median-based std is updated every "frames_init" frames.
        prealloc (bool, default to True): True if pre-allocate memory space for large variables.
            Achieve faster speed at the cost of higher memory occupation.
        display (bool, default to True): True if timing information of each stage is printed.

    Outputs:
        network_input (numpy.ndarray of float32, shape = (T,Lx,Ly)): the SNR video obtained after pre-processing.
            T is the number of processed frames, shortened by (kernel length - 1) when temporal filtering is used.
            Lx and Ly are the lateral dimensions padded up to multiples of 8,
            so that the images can be processed by the shallow U-Net.
        start (float): the starting time of the pipeline (after the data is loaded into memory)
        In addition, the SNR video is saved in "dir_network_input" as a "(Exp_ID).h5" file containing a dataset "network_input".
    '''
    nn = Params['nn']
    if useSF:
        gauss_filt_size = Params['gauss_filt_size']
    Poisson_filt = Params['Poisson_filt']

    h5_video = os.path.join(dir_video, Exp_ID + '.h5')
    # The context manager guarantees the input file is closed even if
    # planning, allocation, or loading raises.
    with h5py.File(h5_video, 'r') as h5_file:
        dset = find_dataset(h5_file)
        video = h5_file[dset]  # dataset handle, looked up once instead of per frame
        (nframes, rows, cols) = video.shape
        # Make the lateral number of pixels a multiple of 8, so that the CNN can process them
        rowspad = math.ceil(rows / 8) * 8
        colspad = math.ceil(cols / 8) * 8
        # Only keep the first "nn" frames to process
        nframes = min(nframes, nn)

        if useSF:
            # lateral dimensions slightly larger than the raw video but faster for FFT
            rows1 = cv2.getOptimalDFTSize(rows)
            cols1 = cv2.getOptimalDFTSize(cols)

            if display:
                print(rows, cols, nn, '->', rows1, cols1, nn)
                start_plan = time.time()
            # if the learned wisdom files have been saved, load them. Otherwise, learn wisdom later
            Length_data = str((nn, rows1, cols1))
            # NOTE(review): the path prefix uses a Windows separator; on POSIX the
            # backslash becomes part of the file name. Left unchanged because the
            # code that writes the wisdom files is not visible here - confirm.
            cc = load_wisdom_txt('wisdom\\' + Length_data)
            if cc:
                pyfftw.import_wisdom(cc)

            # FFT planning
            bb, bf, fft_object_b, fft_object_c = plan_fft(nn, (rows1, cols1),
                                                          prealloc)
            if display:
                end_plan = time.time()
                print('FFT planning: {} s'.format(end_plan - start_plan))

            # %% Initialization: Calculate the spatial filter and set variables.
            mask2 = plan_mask2((rows, cols), (rows1, cols1), gauss_filt_size)
        else:
            # Without spatial filtering no FFT buffers are needed;
            # the raw video is loaded straight into a padded array.
            bb = np.zeros((nframes, rowspad, colspad), dtype='float32')
            fft_object_b = None
            fft_object_c = None
            bf = None
            end_plan = time.time()
            mask2 = None

        # %% Initialization: Set variables.
        if useTF:
            leng_tf = Poisson_filt.size
            # Number of frames after temporal filtering
            nframesf = nframes - leng_tf + 1
            frames_initf = frames_init - leng_tf + 1
        else:
            # No temporal filtering: nothing is trimmed, so the filtered counts
            # equal the raw counts. (Previously "leng_tf" was referenced here
            # without being defined, raising NameError when useTF was False.)
            nframesf = nframes
            frames_initf = frames_init
        if prealloc:
            network_input = np.ones((nframesf, rowspad, colspad), dtype='float32')
            med_frame2 = np.ones((rowspad, colspad, 2), dtype='float32')
        else:
            network_input = np.zeros((nframesf, rowspad, colspad), dtype='float32')
            med_frame2 = np.zeros((rowspad, colspad, 2), dtype='float32')
        if display:
            end_init = time.time()
            print('Initialization: {} s'.format(end_init - end_plan))

        # %% Load the raw video into "bb"
        # Read frame by frame to save memory; indexing the dataset already
        # yields a numpy array, so no extra copy is made before assignment.
        for t in range(nframes):
            bb[t, :rows, :cols] = video[t]

    if display:
        end_load = time.time()
        print('data loading: {} s'.format(end_load - end_init))

    # The pipeline starts after the video is loaded into memory
    start = time.time()
    network_input = preprocess_complete_online(
        bb, (rowspad, colspad), network_input, med_frame2,
        Poisson_filt, mask2, bf, fft_object_b, fft_object_c, useSF, useTF, useSNR,
        med_subtract, update_baseline, frames_initf, frames_init, prealloc, display)

    if display:
        end = time.time()
        print('total per frame: {} ms'.format((end - start) / nframes * 1000))

    # if "dir_network_input" is not None, save network_input to an ".h5" file
    if dir_network_input:
        with h5py.File(os.path.join(dir_network_input, Exp_ID + ".h5"), "w") as f:
            f.create_dataset("network_input", data=network_input)
        if display:
            end_saving = time.time()
            print('Network_input saving: {} s'.format(end_saving - end))

    return network_input, start