Code Example #1
def extract_masks(scan, mmap_scan, num_components=200, num_background_components=1,
                  merge_threshold=0.8, init_on_patches=True, init_method='greedy_roi',
                  soma_diameter=(14, 14), snmf_alpha=None, patch_size=(50, 50),
                  proportion_patch_overlap=0.2, num_components_per_patch=5,
                  num_processes=8, num_pixels_per_process=5000, fps=15):
    """ Extract masks from multi-photon scans using CNMF.

    Uses constrained non-negative matrix factorization to find spatial components (masks)
    and their fluorescence traces in a scan. Default values work well for somatic scans.

    Performed operations are:
        [Initialization on full image | Initialization on patches -> merge components] ->
        spatial update -> temporal update -> merge components -> spatial update ->
        temporal update

    :param np.array scan: 3-dimensional scan (image_height, image_width, num_frames).
    :param np.memmap mmap_scan: 2-d scan (image_height * image_width, num_frames)
    :param int num_components: An estimate of the number of spatial components in the scan
    :param int num_background_components: Number of components to model the background.
    :param float merge_threshold: Maximal temporal correlation allowed between the
        activity of overlapping components before merging them.
    :param bool init_on_patches: If True, run the initialization methods on small patches
        of the scan rather than on the whole image.
    :param string init_method: Initialization method for the components.
        'greedy_roi': Look for a gaussian-shaped patch, apply rank-1 NMF, store
            components, calculate residual scan and repeat for num_components.
        'sparse_nmf': Regularized non-negative matrix factorization (as impl. in sklearn)
    :param (float, float) soma_diameter: Estimated neuron size in y and x (pixels). Used
        in 'greedy_roi' initialization to search for neurons of this size.
    :param float snmf_alpha: Regularization parameter (alpha) for sparse NMF (if used).
    :param (float, float) patch_size: Size of the patches in y and x (pixels).
    :param float proportion_patch_overlap: Patches are sampled in a sliding window. This
        controls how much overlap there is between adjacent patches (0 for none, 0.9 for
        90%).
    :param int num_components_per_patch: Number of components per patch (used if
        init_on_patches=True).
    :param int num_processes: Number of processes to run in parallel. None for as many
        processes as available cores.
    :param int num_pixels_per_process: Number of pixels that a process handles each
        iteration.
    :param float fps: Frame rate. Used for temporal downsampling and to remove bad
        components.

    :returns: Weighted masks (image_height x image_width x num_components). Inferred
        location of each component.
    :returns: Denoised fluorescence traces (num_components x num_frames).
    :returns: Masks for background components (image_height x image_width x
        num_background_components).
    :returns: Traces for background components (num_background_components x num_frames).
    :returns: Raw fluorescence traces (num_components x num_frames). Fluorescence of each
        component in the scan minus activity from other components and background.

    .. warning:: The produced number of components is not exactly what you ask for because
        some components will be merged or deleted.
    .. warning:: Results are better if scans are nonnegative.
    """
    # Get some params
    image_height, image_width, num_frames = scan.shape

    # Start processes
    log('Starting {} processes...'.format(num_processes))
    pool = mp.Pool(processes=num_processes)

    # Initialize components
    log('Initializing components...')
    if init_on_patches:
        # TODO: Redo this (per-patch initialization) in a nicer/more efficient way

        # Make sure they are integers
        patch_size = np.array(patch_size)
        half_patch_size = np.int32(np.round(patch_size / 2))
        num_components_per_patch = int(round(num_components_per_patch))
        patch_overlap = np.int32(np.round(patch_size * proportion_patch_overlap))
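        # With the defaults, patch_size=(50, 50) and proportion_patch_overlap=0.2
        # yield half_patch_size=(25, 25) and patch_overlap=(10, 10): adjacent
        # patches share a 10-pixel border in each dimension.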

        # Create options dictionary (needed for run_CNMF_patches)
        options = {'patch_params': {'ssub': 'UNUSED.', 'tsub': 'UNUSED', 'nb': num_background_components,
                                    'only_init': True, 'skip_refinement': 'UNUSED.',
                                    'remove_very_bad_comps': False}, # remove_very_bad_comps is unnecessary (same as the default)
                   'preprocess_params': {'check_nan': False}, # check_nan is unnecessary (same as default value)
                   'spatial_params': {'nb': num_background_components}, # nb is unnecessary; it is passed to the function and in init_params
                   'temporal_params': {'p': 0, 'method': 'UNUSED.', 'block_size': 'UNUSED.'},
                   'init_params': {'K': num_components_per_patch, 'gSig': np.array(soma_diameter)/2,
                                   'gSiz': None, 'method': init_method, 'alpha_snmf': snmf_alpha,
                                   'nb': num_background_components, 'ssub': 1, 'tsub': max(int(fps / 2), 1),
                                   'options_local_NMF': 'UNUSED.', 'normalize_init': True,
                                   'rolling_sum': True, 'rolling_length': 100, 'min_corr': 'UNUSED',
                                   'min_pnr': 'UNUSED', 'deconvolve_options_init': 'UNUSED',
                                   'ring_size_factor': 'UNUSED', 'center_psf': 'UNUSED'},
                                   # gSiz, ssub, tsub, options_local_NMF, normalize_init, rolling_sum unnecessary (same as default values)
                   'merging' : {'thr': 'UNUSED.'}}

        # Initialize per patch
        res = map_reduce.run_CNMF_patches(mmap_scan.filename, (image_height, image_width, num_frames),
                                          options, rf=half_patch_size, stride=patch_overlap,
                                          gnb=num_background_components, dview=pool)
        initial_A, initial_C, YrA, initial_b, initial_f, pixels_noise, _ = res

        # Merge spatially overlapping components
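        # ('dummy' is a sentinel so the loop body runs at least once; merging
        # repeats until a pass finds nothing left to merge)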
        merged_masks = ['dummy']
        while len(merged_masks) > 0:
            res = merging.merge_components(mmap_scan, initial_A, initial_b, initial_C,
                                           initial_f, initial_C, pixels_noise,
                                           {'p': 0, 'method': 'cvxpy'}, spatial_params='UNUSED',
                                           dview=pool, thr=merge_threshold, mx=np.Inf)
            initial_A, initial_C, num_components, merged_masks, S, bl, c1, neurons_noise, g = res

        # Delete log files (one per patch)
        log_files = glob.glob('caiman*_LOG_*')
        for log_file in log_files:
            os.remove(log_file)
    else:
        from scipy.sparse import csr_matrix
        if init_method == 'greedy_roi':
            res = _greedyROI(scan, num_components, soma_diameter, num_background_components)
            log('Refining initial components (HALS)...')
            res = initialization.hals(scan, res[0].reshape([image_height * image_width, -1], order='F'),
                                      res[1], res[2].reshape([image_height * image_width, -1], order='F'),
                                      res[3], maxIter=3)
            initial_A, initial_C, initial_b, initial_f = res
        else:
            print('Warning: Running sparse_nmf initialization on the entire field of view '
                  'takes a lot of time.')
            res = initialization.initialize_components(scan, K=num_components, nb=num_background_components,
                                                       method=init_method, alpha_snmf=snmf_alpha)
            initial_A, initial_C, initial_b, initial_f, _ = res
        initial_A = csr_matrix(initial_A)
    log(initial_A.shape[-1], 'components found...')

    # Remove bad components (based on spatial consistency and spiking activity)
    log('Removing bad components...')
    good_indices, _ = components_evaluation.estimate_components_quality(initial_C, scan,
        initial_A, initial_C, initial_b, initial_f, final_frate=fps, r_values_min=0.7,
        fitness_min=-20, fitness_delta_min=-20, dview=pool)
    initial_A = initial_A[:, good_indices]
    initial_C = initial_C[good_indices]
    log(initial_A.shape[-1], 'components remaining...')

    # Estimate noise per pixel
    log('Calculating noise per pixel...')
    pixels_noise, _ = pre_processing.get_noise_fft_parallel(mmap_scan, num_pixels_per_process, pool)

    # Update masks
    log('Updating masks...')
    A, b, C, f = spatial.update_spatial_components(mmap_scan, initial_C, initial_f, initial_A, b_in=initial_b,
                                                   sn=pixels_noise, dims=(image_height, image_width),
                                                   method='dilate', dview=pool,
                                                   n_pixels_per_process=num_pixels_per_process,
                                                   nb=num_background_components)

    # Update traces (no impulse response modelling p=0)
    log('Updating traces...')
    res = temporal.update_temporal_components(mmap_scan, A, b, C, f, nb=num_background_components,
                                              block_size=10000, p=0, method='cvxpy', dview=pool)
    C, A, b, f, S, bl, c1, neurons_noise, g, YrA, _ = res

    # Merge components
    log('Merging overlapping (and temporally correlated) masks...')
    merged_masks = ['dummy']
    while len(merged_masks) > 0:
        res = merging.merge_components(mmap_scan, A, b, C, f, S, pixels_noise, {'p': 0, 'method': 'cvxpy'},
                                       'UNUSED', dview=pool, thr=merge_threshold, bl=bl, c1=c1,
                                       sn=neurons_noise, g=g)
        A, C, num_components, merged_masks, S, bl, c1, neurons_noise, g = res

    # Refine masks
    log('Refining masks...')
    A, b, C, f = spatial.update_spatial_components(mmap_scan, C, f, A, b_in=b, sn=pixels_noise,
                                                   dims=(image_height, image_width),
                                                   method='dilate', dview=pool,
                                                   n_pixels_per_process=num_pixels_per_process,
                                                   nb=num_background_components)

    # Refine traces
    log('Refining traces...')
    res = temporal.update_temporal_components(mmap_scan, A, b, C, f, nb=num_background_components,
                                              block_size=10000, p=0, method='cvxpy', dview=pool)
    C, A, b, f, S, bl, c1, neurons_noise, g, YrA, _ = res

    # Removing bad components (more stringent criteria)
    log('Removing bad components...')
    good_indices, _ = components_evaluation.estimate_components_quality(C + YrA, scan, A,
        C, b, f, final_frate=fps, r_values_min=0.8, fitness_min=-40, fitness_delta_min=-40,
        dview=pool)
    A = A.toarray()[:, good_indices]
    C = C[good_indices]
    YrA = YrA[good_indices]
    log(A.shape[-1], 'components remaining...')

    # Stop processes
    log('Done.')
    pool.close()

    # Get results
    masks = A.reshape((image_height, image_width, -1), order='F') # h x w x num_components
    traces = C  # num_components x num_frames
    background_masks = b.reshape((image_height, image_width, -1), order='F') # h x w x num_background_components
    background_traces = f  # num_background_components x num_frames
    raw_traces = C + YrA  # num_components x num_frames

    # Rescale traces to match scan range (~ np.average(trace*mask, weights=mask))
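    # For a mask m and trace t, a pixel carries m_i * t, so the mask-weighted
    # average over pixels is t * sum(m_i**2) / sum(m_i). Scaling traces up and
    # masks down by this factor leaves the product masks @ traces unchanged.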
    scaling_factor = np.sum(masks**2, axis=(0, 1)) / np.sum(masks, axis=(0, 1))
    traces = traces * np.expand_dims(scaling_factor, -1)
    raw_traces = raw_traces * np.expand_dims(scaling_factor, -1)
    masks = masks / scaling_factor
    background_scaling_factor = np.sum(background_masks**2, axis=(0, 1)) / np.sum(background_masks,
                                                                                  axis=(0,1))
    background_traces = background_traces * np.expand_dims(background_scaling_factor, -1)
    background_masks = background_masks / background_scaling_factor

    return masks, traces, background_masks, background_traces, raw_traces
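
A minimal usage sketch (not part of the original source): it builds a small synthetic
scan plus the memory-mapped 2-d copy that extract_masks expects. CaImAn conventionally
encodes the scan dimensions in the memmap file name (usually created via
caiman.save_memmap); the file name, shapes, and parameter values below are
illustrative assumptions only.

import numpy as np

# Synthetic nonnegative scan: 50 x 50 pixels, 500 frames (the docstring above
# recommends nonnegative scans).
scan = np.random.rand(50, 50, 500).astype(np.float32)

# 2-d copy (pixels x frames) backed by a file; np.memmap provides the .filename
# attribute that run_CNMF_patches reads. The file name follows CaImAn's
# 'Yr_d1_..._d2_..._frames_..._.mmap' convention (an assumption here).
mmap_scan = np.memmap('Yr_d1_50_d2_50_d3_1_order_C_frames_500_.mmap',
                      dtype=np.float32, mode='w+', shape=(50 * 50, 500))
mmap_scan[:] = scan.reshape(50 * 50, 500, order='F')
mmap_scan.flush()

masks, traces, background_masks, background_traces, raw_traces = extract_masks(
    scan, mmap_scan, num_components=50, patch_size=(25, 25), fps=15)
print(masks.shape, traces.shape)  # expected: (50, 50, n) and (n, 500)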
Code Example #2
File: cnmf.py  Project: neuron-glia-lab/CaImAn
    def fit(self, images):
        """
        This method uses the cnmf algorithm to find sources in data.

        Parameters
        ----------
        images : memory-mapped np.ndarray of shape (t, x, y[, z]) containing the frames
            of the movie.

        Returns
        -------
        self

        """
        T = images.shape[0]
        dims = images.shape[1:]
        Yr = images.reshape([T, np.prod(dims)], order='F').T
        Y = np.transpose(images, list(range(1, len(dims) + 1)) + [0])
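        # Yr: 2-d (pixels x frames) view of the movie used by the pixel-wise
        # updates; Y: space-first copy (x, y[, z], t) used for initialization
        # and for setting the options below.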
        print((T,) + dims)

        options = CNMFSetParms(Y, self.n_processes, p=self.p, gSig=self.gSig, K=self.k, ssub=self.ssub, tsub=self.tsub,
                               p_ssub=self.p_ssub, p_tsub=self.p_tsub, method_init=self.method_init,
                               n_pixels_per_process=self.n_pixels_per_process, block_size=self.block_size, check_nan=self.check_nan)

        self.options = options

        if self.rf is None:  # no patches
            print('preprocessing ...')

            Yr, sn, g, psx = preprocess_data(Yr, dview=self.dview, **options['preprocess_params'])

            if self.Ain is None:
                print('initializing ...')
                if self.alpha_snmf is not None:
                    options['init_params']['alpha_snmf'] = self.alpha_snmf

                self.Ain, self.Cin, self.b_in, self.f_in, center = initialize_components(
                    Y, normalize=True, **options['init_params'])
                
            if self.only_init:  # only return values after initialization
                nA = np.squeeze(np.array(np.sum(np.square(self.Ain), axis=0)))
                nr = nA.size
                Cin = scipy.sparse.coo_matrix(self.Cin)

                YA = (self.Ain.T.dot(Yr).T) * scipy.sparse.spdiags(old_div(1., nA), 0, nr, nr)
                AA = (self.Ain.T.dot(self.Ain)) * scipy.sparse.spdiags(old_div(1., nA), 0, nr, nr)

                self.YrA = YA - Cin.T.dot(AA)
                self.C = Cin.todense()

                self.bl = None
                self.c1 = None
                self.neurons_sn = None
                self.g = g
                self.A = self.Ain
                self.b = self.b_in
                self.f = self.f_in
                self.sn = sn

                return self

            print('update spatial ...')
            A, b, Cin, self.f_in = update_spatial_components(Yr, self.Cin, self.f_in, self.Ain, sn=sn, dview=self.dview, **options['spatial_params'])

            print('update temporal ...')
            if not self.skip_refinement:
                # set this to zero for fast updating without deconvolution
                options['temporal_params']['p'] = 0
            else:
                options['temporal_params']['p'] = self.p

            options['temporal_params']['method'] = self.method_deconvolution

            C, f, S, bl, c1, neurons_sn, g, YrA = update_temporal_components(
                Yr, A, b, Cin, self.f_in, dview=self.dview, **options['temporal_params'])

            if not self.skip_refinement:

                g1 = g 

                for _ in range(self.N_iterations_refinement):

                    if self.do_merge:
                        print('merge components ...')
                        A, C, nr, merged_ROIs, S, bl, c1, sn1, g1 = merge_components(
                            Yr, A, b, C, f, S, sn, options['temporal_params'],
                            options['spatial_params'], dview=self.dview, bl=bl, c1=c1,
                            sn=neurons_sn, g=g1, thr=self.merge_thresh, mx=50,
                            fast_merge=True)

                    print((A.shape))

                    print('update spatial ...')

                    A, b, C, f = update_spatial_components(
                        Yr, C, f, A, sn=sn, dview=self.dview, **options['spatial_params'])

                    # set it back to original value to perform full deconvolution
                    options['temporal_params']['p'] = self.p
                    print('update temporal ...')
                    C, f, S, bl, c1, neurons_sn, g1, YrA = update_temporal_components(
                        Yr, A, b, C, f, dview=self.dview, bl=None, c1=None, sn=None, g=None, **options['temporal_params'])

            else:
                # skip_refinement: keep the unrefined traces (Cin) and reuse g
                C = Cin
                g1 = g

        else:  # use patches

            if self.stride is None:
                self.stride = int(self.rf * 2 * .1)
                print('**** Setting the stride to 10% of 2*rf automatically: ' + str(self.stride))

            if type(images) is np.ndarray:
                raise Exception(
                    'You need to provide a memory mapped file as input if you use patches!!')

            if self.only_init:
                options['patch_params']['only_init'] = True

            if self.alpha_snmf is not None:
                options['init_params']['alpha_snmf'] = self.alpha_snmf

            A, C, YrA, b, f, sn, optional_outputs = run_CNMF_patches(images.filename, dims + (T,), options, rf=self.rf, stride=self.stride,
                                                                     dview=self.dview, memory_fact=self.memory_fact, gnb=self.gnb)

            options = CNMFSetParms(Y, self.n_processes, p=self.p, gSig=self.gSig,
                                   K=A.shape[-1], thr=self.merge_thresh,
                                   n_pixels_per_process=self.n_pixels_per_process,
                                   block_size=self.block_size, check_nan=self.check_nan)

#            pix_proc=np.minimum(np.int((d1*d2)/self.n_processes/(old_div(T,2000.))),np.int(old_div((d1*d2),self.n_processes))) # regulates the amount of memory used
#            options['spatial_params']['n_pixels_per_process']=pix_proc
#            options['temporal_params']['n_pixels_per_process']=pix_proc
            options['temporal_params']['method'] = self.method_deconvolution

            print("merging")
            merged_ROIs = [0]
            while len(merged_ROIs) > 0:
                A, C, nr, merged_ROIs, S, bl, c1, sn_n, g = merge_components(
                    Yr, A, [], np.array(C), [], np.array(C), [],
                    options['temporal_params'], options['spatial_params'],
                    dview=self.dview, thr=self.merge_thresh, mx=np.Inf)

            print("update temporal")
            C, f, S, bl, c1, neurons_sn, g1, YrA = update_temporal_components(
                Yr, A, b, C, f, dview=self.dview, bl=None, c1=None, sn=None, g=None, **options['temporal_params'])

#           idx_components, fitness, erfc ,r_values, num_significant_samples = evaluate_components(Y,C+YrA,A,N=self.N_samples_fitness,robust_std=self.robust_std,thresh_finess=self.fitness_threshold)
#           sure_in_idx= idx_components[np.logical_and(np.array(num_significant_samples)>0 ,np.array(r_values)>=self.corr_threshold)]
#
#           print ('Keeping ' + str(len(sure_in_idx)) + ' components out of ' + str(len(idx_components)))
#
#
#           A=A[:,sure_in_idx]
#           C=C[sure_in_idx,:]
#           YrA=YrA[sure_in_idx]

        self.A = A
        self.C = C
        self.b = b
        self.f = f
        self.S = S
        self.YrA = YrA
        self.sn = sn
        self.g = g1
        self.bl = bl
        self.c1 = c1
        self.neurons_sn = neurons_sn

        return self
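
A hedged usage sketch for the patch-based path of fit(), written against the older
CaImAn API this snippet comes from; constructor arguments changed in later releases,
so treat the keyword names and values below as assumptions, and 'movie.tif' as a
placeholder file name.

import numpy as np
import caiman as cm
from caiman.source_extraction.cnmf import cnmf

# Memory-map the movie first: fit() raises an exception for plain arrays
# whenever patches (rf is not None) are requested.
fname_new = cm.save_memmap(['movie.tif'], base_name='Yr', order='C')
Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order='F')

cnm = cnmf.CNMF(n_processes=8, k=30, gSig=[4, 4], merge_thresh=0.8, p=2,
                dview=None, rf=25, stride=5, gnb=1, method_init='greedy_roi')
cnm = cnm.fit(images)
print(cnm.A.shape, cnm.C.shape)  # spatial components and their temporal traces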