Exemple #1
0
    def mrcnn(summary_path='/home/nel/data/voltage_data/volpy_paper/memory/summary.npz'):
        """Detect neuron ROIs in a stored summary image with Mask R-CNN.

        Args:
            summary_path: str
                ``.npz`` archive whose ``'arr_0'`` entry is the summary image
                passed to the detector; defaults to the previously hard-coded
                location, so existing no-argument calls behave as before.

        Returns:
            ROIs: 3-D array
                detected masks with the instance axis moved first
                (``r['masks'].transpose([2, 0, 1])``).
        """
        # np.load on an .npz returns an NpzFile that holds the archive open
        # until closed; indexing it reads the array into memory, so it is safe
        # to close the archive right after extracting 'arr_0'.
        with np.load(summary_path) as summary_file:
            summary_image = summary_file['arr_0']
        config = neurons.NeuronsConfig()

        class InferenceConfig(config.__class__):
            # Run detection on one image at a time
            GPU_COUNT = 1
            IMAGES_PER_GPU = 1
            DETECTION_MIN_CONFIDENCE = 0.7
            IMAGE_RESIZE_MODE = "pad64"
            IMAGE_MAX_DIM = 512
            RPN_NMS_THRESHOLD = 0.7
            POST_NMS_ROIS_INFERENCE = 1000

        config = InferenceConfig()
        config.display()
        model_dir = os.path.join(caiman_datadir(), 'model')
        DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0

        with tf.device(DEVICE):
            model = modellib.MaskRCNN(mode="inference",
                                      model_dir=model_dir,
                                      config=config)

        weights_path = download_model('mask_rcnn')
        model.load_weights(weights_path, by_name=True)

        results = model.detect([summary_image], verbose=1)
        r = results[0]
        ROIs = r['masks'].transpose([2, 0, 1])
        return ROIs
Exemple #2
0
def mrcnn_inference(img, weights_path, display_result=True):
    """ Segment neurons in a VolPy summary image with Mask R-CNN.

    Args:
        img: 2-D array
            summary image on which detection is run

        weights_path: str
            file holding the trained Mask R-CNN weights

        display_result: boolean
            when True, draw the detected instances on a new matplotlib figure

    Return:
        ROIs: 3-D array
            detected masks with the instance axis first
            (# of components * # of pixels in x dim * # of pixels in y dim)
    """
    from caiman.source_extraction.volpy.mrcnn import visualize, neurons
    import caiman.source_extraction.volpy.mrcnn.model as modellib

    # Derive the inference settings from the class of a default NeuronsConfig.
    base_config_cls = type(neurons.NeuronsConfig())

    class InferenceConfig(base_config_cls):
        # one image per pass on a single device
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1
        DETECTION_MIN_CONFIDENCE = 0.7
        IMAGE_RESIZE_MODE = "pad64"
        IMAGE_MAX_DIM = 512
        RPN_NMS_THRESHOLD = 0.7
        POST_NMS_ROIS_INFERENCE = 1000

    inference_config = InferenceConfig()
    inference_config.display()

    weights_dir = os.path.join(caiman_datadir(), 'model')
    device_name = "/cpu:0"  # /cpu:0 or /gpu:0
    with tf.device(device_name):
        net = modellib.MaskRCNN(mode="inference",
                                model_dir=weights_dir,
                                config=inference_config)
    net.load_weights(weights_path, by_name=True)

    # detect() returns one result dict per input image; we pass exactly one.
    detection = net.detect([img], verbose=1)[0]
    rois = detection['masks'].transpose((2, 0, 1))

    if display_result:
        _, axis = plt.subplots(1, 1, figsize=(16, 16))
        visualize.display_instances(img,
                                    detection['rois'],
                                    detection['masks'],
                                    detection['class_ids'],
                                    ['BG', 'neurons'],
                                    detection['scores'],
                                    ax=axis,
                                    title="Predictions")
    return rois
Exemple #3
0
def test_computational_performance(fnames, path_ROIs, n_processes,
                                   method='manual_annotation'):
    """Benchmark the VolPy pipeline stage by stage.

    Runs motion correction, memory mapping, segmentation and spike
    extraction on ``fnames`` and records the wall-clock time of each stage.

    Args:
        fnames: str
            path of the movie to process
        path_ROIs: str
            HDF5 file whose ``'mov'`` dataset holds the ROIs; only read when
            ``method == 'manual_annotation'``
        n_processes: int
            number of worker processes for each ``multiprocessing.Pool``
        method: str
            segmentation method: ``'manual_annotation'`` (default, the
            previously hard-coded choice), ``'quick_annotation'`` or
            ``'maskrcnn'``

    Returns:
        list of float: seconds spent in
        [motion correction, memory mapping, segmentation, spike extraction]

    Raises:
        ValueError: if ``method`` is not one of the supported methods.
    """
    # three methods for segmentation:
    # - manual annotation needs user to prepare annotated datasets same format as demo ROIs
    # - quick annotation annotates data with simple interface in python
    # - maskrcnn is a convolutional network trained for finding neurons using summary images
    methods_list = ['manual_annotation', 'quick_annotation', 'maskrcnn']
    if method not in methods_list:
        # Fail before any expensive work; an unknown method would otherwise
        # leave ROIs undefined after motion correction had already run.
        raise ValueError('method must be one of %s, got %r' %
                         (methods_list, method))

    import os
    import cv2
    import glob
    import logging
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    import h5py
    from time import time

    try:
        cv2.setNumThreads(0)
    except Exception:
        # best effort only: some OpenCV builds do not support this call
        pass

    try:
        if __IPYTHON__:
            # this is used for debugging purposes only. allows to reload classes
            # when changed
            get_ipython().magic('load_ext autoreload')
            get_ipython().magic('autoreload 2')
    except NameError:
        pass

    import caiman as cm
    from caiman.motion_correction import MotionCorrect
    from caiman.utils.utils import download_demo, download_model
    from caiman.source_extraction.volpy.volparams import volparams
    from caiman.source_extraction.volpy.volpy import VOLPY
    from caiman.source_extraction.volpy.mrcnn import visualize, neurons
    import caiman.source_extraction.volpy.mrcnn.model as modellib
    from caiman.paths import caiman_datadir
    from caiman.summary_images import local_correlations_movie_offline
    from caiman.summary_images import mean_image
    from caiman.source_extraction.volpy.utils import quick_annotation
    from multiprocessing import Pool

    time_start = time()
    print('Start MOTION CORRECTION')

    #%% dataset dependent parameters
    fr = 400  # sample rate of the movie

    # motion correction parameters
    pw_rigid = False  # flag for pw-rigid motion correction
    gSig_filt = (3, 3)  # size of filter, in general gSig (see below),
    # change this one if algorithm does not work
    max_shifts = (5, 5)  # maximum allowed rigid shift
    strides = (48, 48)  # start a new patch for pw-rigid motion correction every x pixels
    overlaps = (24, 24)  # overlap between patches (size of patch strides+overlaps)
    max_deviation_rigid = 3  # maximum deviation allowed for patch with respect to rigid shifts
    border_nan = 'copy'

    opts_dict = {
        'fnames': fnames,
        'fr': fr,
        'pw_rigid': pw_rigid,
        'max_shifts': max_shifts,
        'gSig_filt': gSig_filt,
        'strides': strides,
        'overlaps': overlaps,
        'max_deviation_rigid': max_deviation_rigid,
        'border_nan': border_nan
    }

    opts = volparams(params_dict=opts_dict)

    # %% start a cluster for parallel processing
    dview = Pool(n_processes)
    try:
        # %%% MOTION CORRECTION
        mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))
        mc.motion_correct(save_movie=True)

        time_mc = time() - time_start
        print(time_mc)
        print('START MEMORY MAPPING')

        # %% restart cluster to clean up memory
        dview.terminate()
        dview = Pool(n_processes)

        # %% MEMORY MAPPING
        # with the 'copy' border option the FOV boundaries can be kept,
        # although be careful about the components near the boundaries
        border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0
        # memory map the file in order 'C'
        fname_new = cm.save_memmap_join(mc.mmap_file,
                                        base_name='memmap_',
                                        add_to_mov=border_to_0,
                                        dview=dview,
                                        n_chunks=1000)  # exclude border

        time_mmap = time() - time_start - time_mc
        print('Start Segmentation')
        # %% SEGMENTATION
        # create summary images (computed for every method so timings stay comparable)
        img = mean_image(mc.mmap_file[0], window=1000, dview=dview)
        img = (img - np.mean(img)) / np.std(img)
        Cn = local_correlations_movie_offline(mc.mmap_file[0],
                                              fr=fr,
                                              window=1500,
                                              stride=1500,
                                              winSize_baseline=400,
                                              remove_baseline=True,
                                              dview=dview).max(axis=0)
        img_corr = (Cn - np.mean(Cn)) / np.std(Cn)
        summary_image = np.stack([img, img, img_corr],
                                 axis=2).astype(np.float32)

        if method == 'manual_annotation':
            with h5py.File(path_ROIs, 'r') as fl:
                ROIs = fl['mov'][()]  # load ROIs

        elif method == 'quick_annotation':
            ROIs = quick_annotation(img_corr, min_radius=4, max_radius=10)

        elif method == 'maskrcnn':
            config = neurons.NeuronsConfig()

            class InferenceConfig(config.__class__):
                # Run detection on one image at a time
                GPU_COUNT = 1
                IMAGES_PER_GPU = 1
                DETECTION_MIN_CONFIDENCE = 0.7
                IMAGE_RESIZE_MODE = "pad64"
                IMAGE_MAX_DIM = 512
                RPN_NMS_THRESHOLD = 0.7
                POST_NMS_ROIS_INFERENCE = 1000

            config = InferenceConfig()
            config.display()
            model_dir = os.path.join(caiman_datadir(), 'model')
            DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0
            with tf.device(DEVICE):
                model = modellib.MaskRCNN(mode="inference",
                                          model_dir=model_dir,
                                          config=config)
            weights_path = download_model('mask_rcnn')
            model.load_weights(weights_path, by_name=True)
            results = model.detect([summary_image], verbose=1)
            r = results[0]
            ROIs = r['masks'].transpose([2, 0, 1])

            display_result = False
            if display_result:
                _, ax = plt.subplots(1, 1, figsize=(16, 16))
                visualize.display_instances(summary_image,
                                            r['rois'],
                                            r['masks'],
                                            r['class_ids'],
                                            ['BG', 'neurons'],
                                            r['scores'],
                                            ax=ax,
                                            title="Predictions")

        time_seg = time() - time_mmap - time_mc - time_start
        print('Start SPIKE EXTRACTION')

        # %% restart cluster to clean up memory
        dview.terminate()
        dview = Pool(n_processes, maxtasksperchild=1)

        # %% parameters for trace denoising and spike extraction
        fnames = fname_new  # change file to the memory-mapped movie
        index = list(range(len(ROIs)))  # index of neurons
        weights = None  # reuse spatial weights

        tau_lp = 5  # parameter for high-pass filter to remove photobleaching
        threshold = 4  # threshold for finding spikes, increase threshold to find less spikes
        contextSize = 35  # number of pixels surrounding the ROI to censor from the background PCA
        flip_signal = True  # Important! Flip signal or not, True for Voltron indicator, False for others

        opts_dict = {
            'fnames': fnames,
            'ROIs': ROIs,
            'index': index,
            'weights': weights,
            'tau_lp': tau_lp,
            'threshold': threshold,
            'contextSize': contextSize,
            'flip_signal': flip_signal
        }

        opts.change_params(params_dict=opts_dict)

        #%% Trace Denoising and Spike Extraction
        vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
        vpy.fit(n_processes=n_processes, dview=dview)

        # %% clean up log files
        log_files = glob.glob('*_LOG_*')
        for log_file in log_files:
            os.remove(log_file)

        # measured before the pool shutdown, exactly as before
        time_ext = time() - time_mmap - time_mc - time_start - time_seg
    finally:
        # The final pool used to be left running (its terminate call was
        # commented out), leaking worker processes on every call; this also
        # shuts down whichever pool is live when a stage raises.
        dview.terminate()

    #%%
    print('file:' + fnames)
    print('number of processes' + str(n_processes))
    print(time_mc)
    print(time_mmap)
    print(time_seg)
    print(time_ext)
    time_list = [time_mc, time_mmap, time_seg, time_ext]

    return time_list
def main():
    """Run the VolPy voltage-imaging demo end to end.

    Steps, in order:
      1. download the demo movie and its ROI file with ``download_demo``;
      2. build the ``volparams`` options (dataset + motion correction);
      3. optionally play the raw movie (disabled by ``display_images``);
      4. motion-correct the movie on a local cluster;
      5. memory-map the corrected movie and point ``fnames`` at it;
      6. obtain ROIs either from the HDF5 ``'mov'`` dataset (manual
         annotation) or by running Mask R-CNN on a mean/correlation summary
         image (``use_maskrcnn``);
      7. run VOLPY trace denoising and spike extraction;
      8. plot trace, spike times, ROI and spatial filter of neuron ``n = 0``;
      9. stop the cluster and delete ``*_LOG_*`` files in the working dir.

    Relies on module-level names not visible in this function (``cm``,
    ``np``, ``h5py``, ``tf``, ``plt``, ``glob``, ``os``, ``MotionCorrect``,
    ``VOLPY``, ``volparams``, ``download_demo``, ``download_model``,
    ``neurons``, ``modellib``, ``visualize``, ``caiman_datadir``).
    """
    pass  # For compatibility between running under Spyder and the CLI

    # %%  Load demo movie and ROIs
    fnames = download_demo('demo_voltage_imaging.hdf5', 'volpy')  # file path to movie file (will download if not present)
    path_ROIs = download_demo('demo_voltage_imaging_ROIs.hdf5', 'volpy')  # file path to ROIs file (will download if not present)

    # %% Setup some parameters for data and motion correction
    # dataset parameters
    fr = 400                                        # sample rate of the movie
    ROIs = None                                     # Region of interests
    index = None                                    # index of neurons
    weights = None                                  # reuse spatial weights by
                                                    # opts.change_params(params_dict={'weights':vpy.estimates['weights']})
    # motion correction parameters
    pw_rigid = False                                # flag for pw-rigid motion correction
    gSig_filt = (3, 3)                              # size of filter, in general gSig (see below),
                                                    # change this one if algorithm does not work
    max_shifts = (5, 5)                             # maximum allowed rigid shift
    strides = (48, 48)                              # start a new patch for pw-rigid motion correction every x pixels
    overlaps = (24, 24)                             # overlap between patches (size of patch strides+overlaps)
    max_deviation_rigid = 3                         # maximum deviation allowed for patch with respect to rigid shifts
    border_nan = 'copy'

    opts_dict = {
        'fnames': fnames,
        'fr': fr,
        'index': index,
        'ROIs': ROIs,
        'weights': weights,
        'pw_rigid': pw_rigid,
        'max_shifts': max_shifts,
        'gSig_filt': gSig_filt,
        'strides': strides,
        'overlaps': overlaps,
        'max_deviation_rigid': max_deviation_rigid,
        'border_nan': border_nan
    }

    opts = volparams(params_dict=opts_dict)

    # %% play the movie (optional)
    # playing the movie using opencv. It requires loading the movie in memory.
    # To close the video press q
    display_images = False

    if display_images:
        m_orig = cm.load(fnames)
        ds_ratio = 0.2
        moviehandle = m_orig.resize(1, 1, ds_ratio)
        moviehandle.play(q_max=99.5, fr=60, magnification=2)

    # %% start a cluster for parallel processing

    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=None, single_thread=False)

    # %%% MOTION CORRECTION
    # Create a motion correction object with the specified parameters
    mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))
    # Run motion correction (rigid here, because pw_rigid is False above)
    mc.motion_correct(save_movie=True)
    # shut this cluster down; a fresh one is started for memory mapping
    dview.terminate()

    # %% motion correction compared with original movie
    display_images = False

    if display_images:
        m_orig = cm.load(fnames)
        m_rig = cm.load(mc.mmap_file)
        ds_ratio = 0.2
        moviehandle = cm.concatenate([m_orig.resize(1, 1, ds_ratio) - mc.min_mov * mc.nonneg_movie,
                                      m_rig.resize(1, 1, ds_ratio)], axis=2)
        moviehandle.play(fr=60, q_max=99.5, magnification=2)  # press q to exit

        # % movie subtracted from the mean
        m_orig2 = (m_orig - np.mean(m_orig, axis=0))
        m_rig2 = (m_rig - np.mean(m_rig, axis=0))
        moviehandle1 = cm.concatenate([m_orig2.resize(1, 1, ds_ratio),
                                       m_rig2.resize(1, 1, ds_ratio)], axis=2)
        moviehandle1.play(fr=60, q_max=99.5, magnification=2)

    # %% Memory Mapping
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=None, single_thread=False)

    # with border_nan == 'copy' nothing is added to the movie before mapping
    border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0
    fname_new = cm.save_memmap_join(mc.mmap_file, base_name='memmap_',
                               add_to_mov=border_to_0, dview=dview, n_chunks=10)

    dview.terminate()

    # %% change fnames to the new motion corrected one
    opts.change_params(params_dict={'fnames': fname_new})

    # %% SEGMENTATION
    use_maskrcnn = True  # set to True to predict the ROIs using the mask R-CNN
    if not use_maskrcnn:                 # use manual annotations
        with h5py.File(path_ROIs, 'r') as fl:
            ROIs = fl['mov'][()]  # load ROIs
        opts.change_params(params_dict={'ROIs': ROIs,
                                        'index': list(range(ROIs.shape[0])),
                                        'method': 'SpikePursuit'})
    else:
        # Create mean and correlation image from the first 20000 frames
        m = cm.load(mc.mmap_file[0], subindices=slice(0, 20000))
        m.fr = fr
        img = m.mean(axis=0)
        img = (img-np.mean(img))/np.std(img)  # z-score the mean image
        # NOTE(review): computeDFF is called with in_place=True and its first
        # output is then subtracted from m before computing local
        # correlations — confirm this is the intended baseline removal.
        m1 = m.computeDFF(secsWindow=1, in_place=True)[0]
        m = m - m1
        Cn = m.local_correlations(swap_dim=False, eight_neighbours=True)
        img_corr = (Cn-np.mean(Cn))/np.std(Cn)  # z-score the correlation image
        # 3-channel input for Mask R-CNN: mean, mean, correlation
        summary_image = np.stack([img, img, img_corr], axis=2).astype(np.float32)
        # free the loaded movie copies before running the network
        del m
        del m1

        # %%
        # Mask R-CNN
        config = neurons.NeuronsConfig()

        class InferenceConfig(config.__class__):
            # Run detection on one image at a time
            GPU_COUNT = 1
            IMAGES_PER_GPU = 1
            DETECTION_MIN_CONFIDENCE = 0.7
            IMAGE_RESIZE_MODE = "pad64"
            IMAGE_MAX_DIM = 512
            RPN_NMS_THRESHOLD = 0.7
            POST_NMS_ROIS_INFERENCE = 1000

        config = InferenceConfig()
        config.display()
        model_dir = os.path.join(caiman_datadir(), 'model')
        DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0
        with tf.device(DEVICE):
            model = modellib.MaskRCNN(mode="inference", model_dir=model_dir,
                                      config=config)
        weights_path = download_model('mask_rcnn')
        model.load_weights(weights_path, by_name=True)
        results = model.detect([summary_image], verbose=1)
        r = results[0]
        # move the instance axis first: one mask per detected neuron
        ROIs_mrcnn = r['masks'].transpose([2, 0, 1])

        # %% visualize the result
        display_result = False

        if display_result:
            _, ax = plt.subplots(1,1, figsize=(16,16))
            visualize.display_instances(summary_image, r['rois'], r['masks'], r['class_ids'], 
                                    ['BG', 'neurons'], r['scores'], ax=ax,
                                    title="Predictions")

        # %% set rois
        opts.change_params(params_dict={'ROIs':ROIs_mrcnn,
                                        'index':list(range(ROIs_mrcnn.shape[0])),
                                        'method':'SpikePursuit'})

    # %% Trace Denoising and Spike Extraction
    c, dview, n_processes = cm.cluster.setup_cluster(
            backend='local', n_processes=None, single_thread=False, maxtasksperchild=1)
    vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
    vpy.fit(n_processes=n_processes, dview=dview)

    # %% some visualization
    print(np.where(vpy.estimates['passedLocalityTest'])[0])    # neurons that pass locality test
    n = 0  # index of the neuron to plot
    
    # Processed signal and spikes of neurons
    plt.figure()
    plt.plot(vpy.estimates['trace'][n])
    plt.plot(vpy.estimates['spikeTimes'][n],
             np.max(vpy.estimates['trace'][n]) * np.ones(vpy.estimates['spikeTimes'][n].shape),
             color='g', marker='o', fillstyle='none', linestyle='none')
    plt.title('signal and spike times')
    plt.show()

    # Location of neurons by Mask R-CNN or manual annotation
    plt.figure()
    if use_maskrcnn:
        plt.imshow(ROIs_mrcnn[n])
    else:
        plt.imshow(ROIs[n])
    mv = cm.load(fname_new)
    plt.imshow(mv.mean(axis=0),alpha=0.5)  # overlay the mean image on the ROI
    
    # Spatial filter created by algorithm
    plt.figure()
    plt.imshow(vpy.estimates['spatialFilter'][n])
    plt.colorbar()
    plt.title('spatial filter')
    plt.show()
    
    

    # %% STOP CLUSTER and clean up log files
    cm.stop_server(dview=dview)
    log_files = glob.glob('*_LOG_*')
    for log_file in log_files:
        os.remove(log_file)
Exemple #5
0
def _microns_to_pixels(values_micron, pixel_size):
    """Convert a sequence of lengths in microns to a tuple of rounded pixel counts."""
    return tuple(
        np.asarray(np.round(np.asarray(values_micron) / float(pixel_size)),
                   int))


def _spikepursuit_roidir(savedir):
    """Return ``savedir`` with its first 'VolPy' replaced by 'Spikepursuit', or None."""
    marker = 'VolPy'
    start = savedir.find(marker)
    if start < 0:
        # find() == -1 used to splice a garbled path together
        return None
    return savedir[:start] + 'Spikepursuit' + savedir[start + len(marker):]


def _load_matlab_rois(mat_path):
    """Load the 'ROIs' variable of a .mat file as an (n_rois, x, y) boolean array."""
    rois = loadmat(mat_path)['ROIs']
    if np.ndim(rois) == 3:
        # the ROI index is stored on the last axis; VolPy expects it first
        return np.moveaxis(np.asarray(rois, bool), 2, 0)
    # a single mask: add the leading ROI axis, cast to bool like the 3-D case
    return np.asarray(rois, bool)[np.newaxis]


def _save_motion_correction(mc_now, path):
    """Pickle the results stored on a MotionCorrect object to ``path``."""
    motioncorr = {
        'fname': mc_now.fname,
        'fname_tot_rig': mc_now.fname_tot_rig,
        'mmap_file': mc_now.mmap_file,
        'min_mov': mc_now.min_mov,
        'shifts_rig': mc_now.shifts_rig,
        'shifts_opencv': mc_now.shifts_opencv,
        'niter_rig': mc_now.niter_rig,
        'templates_rig': mc_now.templates_rig,
        'total_template_rig': mc_now.total_template_rig,
    }
    try:
        # element-wise shifts may be absent (presumably on the rigid pass)
        motioncorr['x_shifts_els'] = mc_now.x_shifts_els
        motioncorr['y_shifts_els'] = mc_now.y_shifts_els
    except AttributeError:
        pass
    with open(path, 'wb') as outfile:
        pickle.dump(motioncorr, outfile)


def run_caiman_pipeline(movie, fr, fnames, savedir, usematlabroi):
    """Motion-correct a voltage-imaging movie, find ROIs and extract spikes.

    Runs a rigid and then a piecewise-rigid motion-correction pass, memory
    maps the result, takes ROIs from a MATLAB ``ROIs.mat`` (when requested
    and available) or from Mask R-CNN, runs VOLPY spike extraction, and
    writes the results into ``savedir``.

    Args:
        movie: mapping
            must provide ``'movie_pixel_size'``, used to convert the micron
            parameters below into pixels (presumably microns per pixel)
        fr: float
            frame rate of the movie
        fnames: movie file(s) passed to ``MotionCorrect``
        savedir: str
            output directory; receives ``parameters.pickle``,
            ``spikepursuit.pickle``, ``motion_corr_0/1.pickle`` and the
            memory-mapped corrected movie
        usematlabroi: bool
            use ``ROIs.mat`` from the 'Spikepursuit' twin of ``savedir``
            (its first 'VolPy' replaced) when that file exists

    Side effects: deletes both passes' intermediate mmap files and the
    ``*_LOG_*`` files in the working directory.
    """
    cpu_num = 7  # workers for motion correction and memory mapping
    cpu_num_spikepursuit = 2  # workers for spike extraction

    gsig_filt_micron = (4, 4)
    max_shifts_micron = (6, 6)
    strides_micron = (30, 30)
    overlaps_micron = (15, 15)
    max_deviation_rigid_micron = 4

    pixel_size = movie['movie_pixel_size']

    ROIs = None  # Region of interests
    index = None  # index of neurons
    weights = None  # reuse spatial weights by
    # opts.change_params(params_dict={'weights':vpy.estimates['weights']})
    # motion correction parameters, converted from microns to pixels
    pw_rigid = False  # first pass is rigid; a pw-rigid pass follows below
    gSig_filt = _microns_to_pixels(gsig_filt_micron, pixel_size)  # size of filter
    max_shifts = _microns_to_pixels(max_shifts_micron, pixel_size)
    # start a new patch for pw-rigid motion correction every x pixels
    strides = _microns_to_pixels(strides_micron, pixel_size)
    # overlap between patches (size of patch strides+overlaps)
    overlaps = _microns_to_pixels(overlaps_micron, pixel_size)
    # maximum deviation allowed for patch with respect to rigid shifts
    max_deviation_rigid = int(round(max_deviation_rigid_micron / pixel_size))
    border_nan = 'copy'
    opts_dict = {
        'fnames': fnames,
        'fr': fr,
        'index': index,
        'ROIs': ROIs,
        'weights': weights,
        'pw_rigid': pw_rigid,
        'max_shifts': max_shifts,
        'gSig_filt': gSig_filt,
        'strides': strides,
        'overlaps': overlaps,
        'max_deviation_rigid': max_deviation_rigid,
        'border_nan': border_nan
    }
    opts = volparams(params_dict=opts_dict)

    # %% play the movie (optional)
    # playing the movie using opencv. It requires loading the movie in memory.
    # To close the video press q
    display_images = False

    if display_images:
        m_orig = cm.load(fnames)
        ds_ratio = 0.2
        moviehandle = m_orig.resize(1, 1, ds_ratio)
        moviehandle.play(q_max=99.5, fr=60, magnification=2)

    # %% MOTION CORRECTION, pass 1 (rigid)
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=cpu_num,
                                                     single_thread=False)
    mcrig = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))
    mcrig.motion_correct(save_movie=True)
    dview.terminate()

    # %% MOTION CORRECTION, pass 2 (piecewise rigid on the rigid output)
    opts.change_params({'pw_rigid': True})
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=cpu_num,
                                                     single_thread=False)
    mc = MotionCorrect(mcrig.mmap_file,
                       dview=dview,
                       **opts.get_group('motion'))
    mc.motion_correct(save_movie=True)
    dview.terminate()

    # %% motion correction compared with original movie
    display_images = False
    if display_images:
        m_orig = cm.load(fnames)
        m_rig = cm.load(mcrig.mmap_file)
        m_pwrig = cm.load(mc.mmap_file)
        ds_ratio = 0.2
        moviehandle = cm.concatenate([
            m_orig.resize(1, 1, ds_ratio) - mc.min_mov * mc.nonneg_movie,
            m_rig.resize(1, 1, ds_ratio),
            m_pwrig.resize(1, 1, ds_ratio)
        ], axis=2)
        moviehandle.play(fr=60, q_max=99.5, magnification=2)  # press q to exit
        # % movie subtracted from the mean
        m_orig2 = (m_orig - np.mean(m_orig, axis=0))
        m_rig2 = (m_rig - np.mean(m_rig, axis=0))
        m_pwrig2 = (m_pwrig - np.mean(m_pwrig, axis=0))
        moviehandle1 = cm.concatenate([
            m_orig2.resize(1, 1, ds_ratio),
            m_rig2.resize(1, 1, ds_ratio),
            m_pwrig2.resize(1, 1, ds_ratio)
        ], axis=2)
        moviehandle1.play(fr=60, q_max=99.5, magnification=2)

    # %% Memory Mapping
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=cpu_num,
                                                     single_thread=False)
    border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0
    fname_new = cm.save_memmap_join(mc.mmap_file,
                                    base_name='memmap_',
                                    add_to_mov=border_to_0,
                                    dview=dview,
                                    n_chunks=10)
    dview.terminate()

    # %% change fnames to the new motion corrected one
    opts.change_params(params_dict={'fnames': fname_new})

    # %% SEGMENTATION
    roidir = _spikepursuit_roidir(savedir)
    files = []
    if roidir is not None:
        try:
            files = os.listdir(roidir)
        except OSError:
            # a missing/unreadable directory just means "no MATLAB ROIs"
            files = []

    if usematlabroi and 'ROIs.mat' in files:
        selected_rois = _load_matlab_rois(os.path.join(roidir, 'ROIs.mat'))
    else:
        print('WTF')
        # Create mean and correlation image. Load errors now propagate: the
        # old bare except silently loaded a hard-coded movie from a
        # different recording, corrupting the analysis without warning.
        m = cm.load(mc.mmap_file[0], subindices=slice(0, 20000))
        m.fr = fr
        img = m.mean(axis=0)
        img = (img - np.mean(img)) / np.std(img)
        m1 = m.computeDFF(secsWindow=1, in_place=True)[0]
        m = m - m1
        Cn = m.local_correlations(swap_dim=False, eight_neighbours=True)
        img_corr = (Cn - np.mean(Cn)) / np.std(Cn)
        # 3-channel Mask R-CNN input: mean, mean, correlation
        summary_image = np.stack([img, img, img_corr],
                                 axis=2).astype(np.float32)
        del m
        del m1

        # Mask R-CNN
        config = neurons.NeuronsConfig()

        class InferenceConfig(config.__class__):
            # Run detection on one image at a time
            GPU_COUNT = 1
            IMAGES_PER_GPU = 1
            DETECTION_MIN_CONFIDENCE = 0.7
            IMAGE_RESIZE_MODE = "pad64"
            IMAGE_MAX_DIM = 512
            RPN_NMS_THRESHOLD = 0.7
            POST_NMS_ROIS_INFERENCE = 1000

        config = InferenceConfig()
        config.display()
        model_dir = os.path.join(caiman_datadir(), 'model')
        DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0
        with tf.device(DEVICE):
            model = modellib.MaskRCNN(mode="inference",
                                      model_dir=model_dir,
                                      config=config)
        weights_path = download_model('mask_rcnn')
        model.load_weights(weights_path, by_name=True)
        results = model.detect([summary_image], verbose=1)
        r = results[0]
        selected_rois = r['masks'].transpose([2, 0, 1])

        # %% visualize the result
        display_result = False
        if display_result:
            _, ax = plt.subplots(1, 1, figsize=(16, 16))
            visualize.display_instances(summary_image,
                                        r['rois'],
                                        r['masks'],
                                        r['class_ids'], ['BG', 'neurons'],
                                        r['scores'],
                                        ax=ax,
                                        title="Predictions")

    # %% set rois (same settings for either ROI source)
    opts.change_params(
        params_dict={
            'ROIs': selected_rois,
            'index': list(range(selected_rois.shape[0])),
            'method': 'SpikePursuit'
        })

    # %% Trace Denoising and Spike Extraction
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local',
        n_processes=cpu_num_spikepursuit,
        single_thread=False,
        maxtasksperchild=1)
    vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
    vpy.fit(n_processes=n_processes, dview=dview)

    # %% save parameters, spike-extraction and motion-correction results
    print('saving parameters')
    parameters = dict()
    parameters['motion'] = opts.motion
    parameters['data'] = opts.data
    parameters['volspike'] = opts.volspike
    with open(os.path.join(savedir, 'parameters.pickle'), 'wb') as outfile:
        pickle.dump(parameters, outfile)

    volspikedata = dict()
    volspikedata['estimates'] = vpy.estimates
    volspikedata['params'] = vpy.params.data
    with open(os.path.join(savedir, 'spikepursuit.pickle'), 'wb') as outfile:
        pickle.dump(volspikedata, outfile)

    for mcidx, mc_now in enumerate([mcrig, mc]):
        _save_motion_correction(
            mc_now,
            os.path.join(savedir, 'motion_corr_' + str(mcidx) + '.pickle'))

    # %% delete intermediate mmap files, keep the joined movie in savedir
    print('moving files')
    for mmap_file in [*mcrig.mmap_file, *mc.mmap_file]:
        os.remove(mmap_file)

    fname = pathlib.Path(fname_new).name
    shutil.move(fname_new, os.path.join(savedir, fname))

    # %% some visualization
    plotstuff = False
    if plotstuff:
        print(np.where(vpy.estimates['passedLocalityTest'])
              [0])  # neurons that pass locality test
        n = 0

        # Processed signal and spikes of neurons
        plt.figure()
        plt.plot(vpy.estimates['trace'][n])
        plt.plot(vpy.estimates['spikeTimes'][n],
                 np.max(vpy.estimates['trace'][n]) *
                 np.ones(vpy.estimates['spikeTimes'][n].shape),
                 color='g',
                 marker='o',
                 fillstyle='none',
                 linestyle='none')
        plt.title('signal and spike times')
        plt.show()
        # Location of the neuron (valid for either ROI source; the old code
        # referenced names undefined on the MATLAB-ROI path)
        plt.figure()
        plt.imshow(selected_rois[n])
        # NOTE: fname_new was moved into savedir above
        mv = cm.load(os.path.join(savedir, fname))
        plt.imshow(mv.mean(axis=0), alpha=0.5)

        # Spatial filter created by algorithm
        plt.figure()
        plt.imshow(vpy.estimates['spatialFilter'][n])
        plt.colorbar()
        plt.title('spatial filter')
        plt.show()

    # %% STOP CLUSTER and clean up log files
    cm.stop_server(dview=dview)
    log_files = glob.glob('*_LOG_*')
    for log_file in log_files:
        os.remove(log_file)