Example #1
File: workers.py  Project: iawn/live2p
    def __init__(self, files, plane, nchannels, nplanes, params):
        """
        Base class for implementing live2p. Don't instantiate this class directly;
        use or make a subclass instead.

        Args:
            files (list): list of files to process
            plane (int): plane number to process; serves as a slice index through each tiff
            nchannels (int): total number of channels; helps with slicing ScanImage tiffs
            nplanes (int): total number of z-planes imaged; helps with slicing ScanImage tiffs
            params (dict): CaImAn params dict
        """
        self.files = files
        self.plane = plane
        self.nchannels = nchannels
        self.nplanes = nplanes

        # setup the params object
        logger.debug('Setting up params...')
        self._params = CNMFParams(params_dict=params)

        self.data_root = Path(self.files[0]).parent
        self.caiman_path = Path()
        self.temp_path = Path()
        self.all_path = Path()
        self._setup_folders()

        # these get set up by _start_cluster, called on run so workers can be
        # queued without ipyparallel clusters clashing
        self.c = None  # little c is ipyparallel-related
        self.dview = None
        self.n_processes = None
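The comment above implies the actual cluster startup happens when a subclass runs. A minimal sketch of what that might look like (the class name and run() body are assumptions, not part of live2p; cm.cluster.setup_cluster and cm.stop_server are standard CaImAn helpers):

import caiman as cm

class QueuedWorker:  # stands in for a subclass of the base class above (name assumed)
    def __init__(self, files, plane):
        self.files, self.plane = files, plane
        self.c = self.dview = self.n_processes = None

    def run(self):
        # start ipyparallel only when the worker actually runs, so queued
        # workers don't clash over the same cluster
        self.c, self.dview, self.n_processes = cm.cluster.setup_cluster(
            backend='local', n_processes=None, single_thread=False)
        try:
            pass  # process self.files for self.plane here
        finally:
            cm.stop_server(dview=self.dview)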
Example #2
File: workers.py  Project: iawn/live2p
    def params(self, params):
        if isinstance(params, dict):
            self._params = CNMFParams(params_dict=params)
        elif isinstance(params, CNMFParams):
            self._params = params
        else:
            raise ValueError('Please supply a dict or cnmf params object.')
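This is the setter half of a params property; a minimal sketch of the likely pairing (the @property getter is assumed, since the snippet does not show it):

from caiman.source_extraction.cnmf.params import CNMFParams

class Worker:  # hypothetical container class
    @property
    def params(self):
        return self._params

    @params.setter
    def params(self, params):
        if isinstance(params, dict):
            self._params = CNMFParams(params_dict=params)
        elif isinstance(params, CNMFParams):
            self._params = params
        else:
            raise ValueError('Please supply a dict or cnmf params object.')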
Example #3
    def load_params_CNMF(self, file, is_patch):

        opts_dict = {
            'tsub': self.get_dict_param('tsub', 'single_int'),
            'ssub': self.get_dict_param('ssub', 'single_int'),
            'fnames': file,
            'decay_time': self.get_dict_param('decay_time', 'single_float'),
            'fr': self.get_dict_param('frate', 'single_float'),
            'nb': self.get_dict_param('gnb', 'single_int'),
            'gSig': self.get_dict_param('gSig', 'tuple_int'),
            'method_init': self.get_dict_param('method_init', 'str'),
            'rolling_sum': True,
            'merge_thr': self.get_dict_param('merge_thresh', 'single_float'),
            'n_processes': self.get_dict_param('n_processes_cnmf',
                                               'single_int'),
        }

        if is_patch:
            opts_dict['K'] = self.get_dict_param('k_patch', 'single_int')
            opts_dict['rf'] = self.get_dict_param('rf', 'single_int')
            opts_dict['stride'] = self.get_dict_param('stride', 'single_int')
        else:
            opts_dict['K'] = self.get_dict_param('k', 'single_int')

        return CNMFParams(params_dict=opts_dict)
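Hypothetical call sites illustrating the patch/non-patch split (loader stands for an instance of the surrounding class; the file path is a placeholder):

# full-frame initialization: reads 'k' from the param dict
opts_full = loader.load_params_CNMF('movie.tif', is_patch=False)

# patch-based initialization: reads 'k_patch', 'rf' and 'stride' instead
opts_patch = loader.load_params_CNMF('movie.tif', is_patch=True)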
Example #4
    def setup(self):
        ''' Create OnACID object and initialize it
                (runs initialize_online)
        '''
        logger.info('Running setup for ' + self.name)
        self.done = False
        self.dropped_frames = []
        self.coords = None
        self.ests = None
        self.A = None
        self.saving = True

        self.loadParams(param_file=self.param_file)
        self.params = self.client.get('params_dict')

        # MUST include initial set of frames
        # TODO: Institute check here as requirement to Nexus
        print(self.params['fnames'])

        self.opts = CNMFParams(params_dict=self.params)
        self.onAc = OnACID(params=self.opts)
        #TODO: Need to rewrite init online as well to receive individual frames.
        self.onAc.initialize_online()
        self.max_shifts_online = self.onAc.params.get('online',
                                                      'max_shifts_online')
Example #5
    def setup(self):
        ''' Using method #2 from the realtime demo, with a short init
            and online processing with OnACID-E
        '''
        logger.info('Running setup for ' + self.name)
        self.done = False
        self.dropped_frames = []
        self.coords = None
        self.ests = None
        self.A = None
        self.num = 0
        self.saving = False

        self.loadParams(param_file=self.param_file)
        self.params = self.client.get('params_dict')

        # MUST include initial set of frames
        print(self.params['fnames'])

        self.opts = CNMFParams(params_dict=self.params)
        self.onAc = OnACID(params=self.opts)
        self.onAc.initialize_online()
        self.max_shifts_online = self.onAc.params.get('online', 'max_shifts_online')
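After initialize_online() has consumed the initial batch, new frames are typically fed to the model one at a time. A minimal sketch of a runner method using OnACID.fit_next, as in CaImAn's online demos (the method name and frame generator are assumptions):

    def run_frames(self, frame_source):  # hypothetical method and frame generator
        import numpy as np
        t = self.onAc.params.get('online', 'init_batch')
        for frame in frame_source:  # each frame is a 2D numpy array
            # fit_next expects a flattened, Fortran-ordered float32 frame
            self.onAc.fit_next(t, frame.astype(np.float32).reshape(-1, order='F'))
            t += 1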
Example #6
    def fit(self, images):
        """
        This method uses the CNMF algorithm to find sources in the data.

        Args:
            images: memory-mapped np.ndarray of shape (t, x, y) containing the movie frames.

        Returns:
            self
        """

        T, d1, d2 = images.shape
        dims = (d1, d2)
        Yr = images.reshape([T, np.prod(dims)], order='F').T
        Y = np.transpose(images, [1, 2, 0])
        print((T, d1, d2))

        options = CNMFParams(dims, K=self.k, gSig=self.gSig, ssub=self.ssub, tsub=self.tsub, p=self.p,
                             p_ssub=self.p_ssub, p_tsub=self.p_tsub, method_init=self.method_init, normalize_init=True)

        self.options = options

        if self.rf is None:

            Yr, sn, g, psx = preprocess_data(
                Yr, dview=self.dview, **options['preprocess_params'])

            if self.Ain is None:
                if self.alpha_snmf is not None:
                    options['init_params']['alpha_snmf'] = self.alpha_snmf

                self.Ain, self.Cin, self.b_in, self.f_in, center = initialize_components(
                    Y, **options['init_params'])

            A, b, Cin, self.f_in = update_spatial_components(Yr, self.Cin, self.f_in, self.Ain, sn=sn,
                                                             dview=self.dview, **options['spatial_params'])

            # set this to zero for fast updating without deconvolution
            options['temporal_params']['p'] = 0

            C, A, b, f, S, bl, c1, neurons_sn, g, YrA = update_temporal_components(Yr, A, b, Cin, self.f_in,
                                                                                   dview=self.dview, **options['temporal_params'])

            if self.do_merge:
                A, C, nr, merged_ROIs, S, bl, c1, sn1, g1 = merge_components(Yr, A, b, C, f, S, sn, options['temporal_params'],
                                                                             options['spatial_params'], dview=self.dview, bl=bl, c1=c1, sn=neurons_sn, g=g,
                                                                             thr=self.merge_thresh, mx=50, fast_merge=True)

            print((A.shape))

            A, b, C, f = update_spatial_components(
                Yr, C, f, A, sn=sn, dview=self.dview, dims=self.dims,  **options['spatial_params'])
            # set it back to original value to perform full deconvolution
            options['temporal_params']['p'] = self.p

            C, A, b, f, S, bl, c1, neurons_sn, g1, YrA = update_temporal_components(Yr, A, b, C, f,
                                                                                    dview=self.dview, bl=None, c1=None, sn=None, g=None, **options['temporal_params'])

        else:  # use patches
            if self.stride is None:
                self.stride = int(self.rf * 2 * .1)
                print('**** Setting the stride to 10% of 2*rf automatically: '
                      + str(self.stride))

            if type(images) is np.ndarray:
                raise Exception(
                    'You need to provide a memory mapped file as input if you use patches!!')

            if self.only_init:
                options['patch_params']['only_init'] = True

            A, C, YrA, b, f, sn, optional_outputs = run_CNMF_patches(images.filename, (d1, d2, T), options, rf=self.rf,
                                                                     stride=self.stride,
                                                                     dview=self.dview, memory_fact=self.memory_fact)

            self.optional_outputs = optional_outputs

            options = CNMFParams(dims, K=A.shape[-1], gSig=self.gSig, p=self.p, thr=self.merge_thresh)
            pix_proc = np.minimum(int((d1 * d2) / self.n_processes / (
                old_div(T, 2000.))), int(old_div((d1 * d2), self.n_processes)))  # regulates the amount of memory used

            options['spatial_params']['n_pixels_per_process'] = pix_proc
            options['temporal_params']['n_pixels_per_process'] = pix_proc
            merged_ROIs = [0]
            self.merged_ROIs = []
            while len(merged_ROIs) > 0:
                A, C, nr, merged_ROIs, S, bl, c1, sn, g = merge_components(Yr, A, [],
                                                                           np.array(C), [], np.array(
                                                                               C), [], options['temporal_params'], options['spatial_params'],
                                                                           dview=self.dview, thr=self.merge_thresh, mx=np.inf)

                self.merged_ROIs.append(merged_ROIs)

            C, A, b, f, S, bl, c1, neurons_sn, g2, YrA = update_temporal_components(
                Yr, A, np.atleast_2d(b).T, C, f, dview=self.dview, bl=None, c1=None, sn=None, g=None, **options['temporal_params'])

        self.A = A
        self.C = C
        self.b = b
        self.f = f
        self.YrA = YrA
        self.sn = sn

        return self
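Since the patch path rejects a plain ndarray and reads images.filename, the movie must be passed in as a memmap. A sketch of a typical call (the memmap path is a placeholder; cm.load_memmap is the standard CaImAn loader, and model stands for an instance of this class):

import caiman as cm
import numpy as np

Yr, dims, T = cm.load_memmap('memmap__d1_512_d2_512_.mmap')  # placeholder path
images = np.reshape(Yr.T, [T] + list(dims), order='F')  # stays a memmap view
model.fit(images)  # returns self with A, C, b, f, YrA, sn populated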
Example #7
def run(batch_dir: str, UUID: str):
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format="%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s"
    )

    start_time = time()

    output = {'status': 0, 'output_info': ''}
    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)
    file_path = os.path.join(batch_dir, UUID)

    filename = [file_path + '_input.tiff']
    input_params = pickle.load(open(file_path + '.params', 'rb'))

    # default to no seed matrix so Ain is defined on every code path below
    Ain = None

    # If Ain is specified
    if input_params['do_cnmfe']:
        item_uuid = input_params['cnmfe_kwargs'].pop('Ain')

        if item_uuid:
            print('>> Ain specified, looking for cnm-A file <<')
            parent_batch_dir = os.environ['CURR_BATCH_DIR']
            item_out_file = os.path.join(parent_batch_dir, f'{item_uuid}.out')
            t0 = time()
            timeout = 60
            while not os.path.isfile(item_out_file):
                print('>>> cnm-A not found, waiting for 15 seconds <<<')
                sleep(15)
                if time() - t0 > timeout:
                    output.update({
                        'status':
                        0,
                        'output_info':
                        'Timeout exceeding in waiting for Ain input file'
                    })
                    raise TimeoutError(
                        'Timeout exceeding in waiting for Ain input file')

            # the wait loop above guarantees the file exists at this point
            if json.load(open(item_out_file, 'r'))['status']:
                Ain_path = os.path.join(parent_batch_dir,
                                        item_uuid + '_results.hdf5')
                Ain = load_dict_from_hdf5(Ain_path)['estimates']['A']
                print('>>> Found Ain file <<<')
            else:
                raise FileNotFoundError(
                    '>>> Could not find specified Ain file <<<')

    print('*********** Creating Process Pool ***********')

    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=n_processes,
                                                     single_thread=False,
                                                     ignore_preexisting=True)

    try:
        print('Creating memmap')

        memmap_path = cm.save_memmap(
            filename,
            base_name=f'memmap-{UUID}',
            order='C',
            dview=dview,
            border_to_0=input_params['border_pix'],
        )

        Yr, dims, T = cm.load_memmap(memmap_path)
        Y = Yr.T.reshape((T, ) + dims, order='F')

        if input_params['do_cnmfe']:
            gSig = input_params['cnmfe_kwargs']['gSig'][0]
        else:
            gSig = input_params['corr_pnr_kwargs']['gSig']

        cn_filter, pnr = cm.summary_images.correlation_pnr(Y,
                                                           swap_dim=False,
                                                           gSig=gSig)

        if not input_params['do_cnmfe'] and input_params['do_corr_pnr']:
            pickle.dump(cn_filter,
                        open(UUID + '_cn_filter.pikl', 'wb'),
                        protocol=4)
            pickle.dump(pnr, open(UUID + '_pnr.pikl', 'wb'), protocol=4)

            output_file_list = \
                [
                    UUID + '_pnr.pikl',
                    UUID + '_cn_filter.pikl',
                ]

            output.update({
                'output': UUID,
                'status': 1,
                'output_info': 'inspect correlation & pnr',
                'output_files': output_file_list
            })

            dview.terminate()

            for mf in glob(batch_dir + '/memmap-*'):
                try:
                    os.remove(mf)
                except OSError:  # Windows doesn't like removing the memmaps
                    pass

            end_time = time()
            processing_time = (end_time - start_time) / 60
            output.update({'processing_time': processing_time})

            json.dump(output, open(file_path + '.out', 'w'))

            return

        cnmf_params_dict = \
            {
                "method_init": 'corr_pnr',
                "n_processes": n_processes,
                "only_init_patch": True,  # for 1p
                "center_psf": True,  # for 1p
                "normalize_init": False,  # for 1p
            }
        cnmf_params_dict.update(**input_params['cnmfe_kwargs'])

        cnm = cnmf.CNMF(
            n_processes=n_processes,
            dview=dview,
            Ain=Ain,
            params=CNMFParams(params_dict=cnmf_params_dict),
        )

        cnm.fit(Y)

        #  DISCARD LOW QUALITY COMPONENTS
        cnm.params.set('quality', {
            'use_cnn': False,
            **input_params['eval_kwargs']
        })

        cnm.estimates.evaluate_components(Y, cnm.params, dview=dview)

        out_filename = f'{UUID}_results.hdf5'
        cnm.save(out_filename)

        pickle.dump(pnr, open(UUID + '_pnr.pikl', 'wb'), protocol=4)
        pickle.dump(cn_filter,
                    open(UUID + '_cn_filter.pikl', 'wb'),
                    protocol=4)

        output.update({
            'output': filename[0][:-5],  # input path without the .tiff extension
            'status': 1,
            'output_files': [
                out_filename,
                UUID + '_pnr.pikl',
                UUID + '_cn_filter.pikl',
            ]
        })

    except Exception:
        output.update({
            'status': 0,
            # Y may not exist if the failure happened before the memmap was loaded
            'Y.shape': Y.shape if 'Y' in locals() else None,
            'output_info': traceback.format_exc()
        })

    dview.terminate()

    for mf in glob(batch_dir + '/memmap-*'):
        try:
            os.remove(mf)
        except OSError:  # Windows doesn't like removing the memmaps
            pass

    end_time = time()
    processing_time = (end_time - start_time) / 60
    output.update({'processing_time': processing_time})

    json.dump(output, open(file_path + '.out', 'w'))
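The .out file written above is plain JSON, so a caller can poll it with something like this sketch (the helper name is hypothetical; key names follow the output dict built in this function):

import json
import os

def check_result(batch_dir: str, UUID: str) -> dict:
    # read back the JSON summary that run() wrote to <batch_dir>/<UUID>.out
    with open(os.path.join(batch_dir, UUID + '.out'), 'r') as f:
        result = json.load(f)
    if result['status'] == 1:
        print(f"done in {result['processing_time']:.1f} min:", result.get('output_files'))
    else:
        print('failed:', result['output_info'])
    return result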
Example #8
def run_cnmfe(tiff_files, param_file, output_file):
    """ Run the CNMFe algorithm through CaImAn.

    :param tiff_files: A list of .tiff files corresponding to a calcium imaging movie.
    :param param_file: A .yaml parameter file, containing values for the following parameters:
        num_processes : int
            The number of processes to run in parallel. The more parallel processes, the more memory that is used.
        rf : array-like
            An array [half-width, half-height] that specifies the size of a patch.
        stride : int
            The amount of overlap in pixels between patches.
        K : int
            The maximum number of cells per patch.
        gSiz : int
            The expected diameter of a neuron in pixels.
        gSig : int
            The standard deviation of a high-pass Gaussian filter applied to the movie prior to seed pixel search,
            roughly equal to the half-size of the neuron in pixels.
        min_pnr : float
            The minimum peak-to-noise ratio that is taken into account when searching for seed pixels.
        min_corr : float
            The minimum pixel correlation that is taken into account when searching for seed pixels.
        min_SNR : float
            Cells with a signal-to-noise ratio (SNR) less than this are rejected.
        rval_thr : float
            Cells with a spatial correlation greater than this are accepted.
        decay_time : float
            The expected decay time of a calcium event in seconds.
        ssub_B : int
            The spatial downsampling factor used on the background term.
        merge_threshold : float
            Cells that are spatially close and have a temporal correlation greater than merge_threshold are
            automatically merged.
    :param output_file: The path to a .hdf5 file that will be written to contain the traces, footprints, and deconvolved
        events identified by CNMFe.
    """

    for tiff_file in tiff_files:
        if not os.path.exists(tiff_file):
            raise FileNotFoundError(tiff_file)

    if not os.path.exists(param_file):
        raise FileNotFoundError(param_file)

    with open(param_file, 'r') as f:
        params = yaml.safe_load(f)

    expected_params = [
        'gSiz', 'gSig', 'K', 'min_corr', 'min_pnr', 'rf', 'stride',
        'decay_time', 'min_SNR', 'rval_thr', 'merge_threshold', 'ssub_B',
        'frame_rate', 'num_rows', 'num_cols', 'num_frames', 'num_processes'
    ]

    for pname in expected_params:
        if pname not in params:
            raise ValueError('Missing parameter {} in file {}'.format(
                pname, param_file))

    gSiz = params['gSiz']
    gSig = params['gSig']
    K = params['K']
    min_corr = params['min_corr']
    min_pnr = params['min_pnr']
    rf = params['rf']
    stride = params['stride']
    decay_time = params['decay_time']
    min_SNR = params['min_SNR']
    rval_thr = params['rval_thr']
    merge_threshold = params['merge_threshold']
    ssub_B = params['ssub_B']
    frame_rate = params['frame_rate']
    num_rows = params['num_rows']
    num_cols = params['num_cols']
    num_frames = params['num_frames']
    num_processes = params['num_processes']

    # write memmapped file
    print('Exporting .isxd to memmap file...')
    mmap_file = _export_movie_to_memmap(tiff_files,
                                        num_frames,
                                        num_rows,
                                        num_cols,
                                        overwrite=False)
    print('Wrote .mmap file to: {}'.format(mmap_file))

    # open memmapped file
    Yr, dims, T = load_memmap(mmap_file)
    Y = Yr.T.reshape((T, ) + dims, order='F')

    # grab parallel IPython handle
    dview = None
    if num_processes > 1:
        import ipyparallel as ipp
        c = ipp.Client()
        dview = c[:]
        print('Running using parallel IPython, # clusters = {}'.format(
            len(c.ids)))
        num_processes = len(c.ids)

    # initialize CNMFE parameter object and set user params
    cnmfe_params = CNMFParams()

    if gSiz is None:
        raise ValueError(
            'You must set gSiz to an integer, ideally roughly equal to the expected half-cell width.'
        )
    gSiz = _turn_into_array(gSiz)

    if gSig is None:
        raise ValueError(
            'You must set gSig to a non-zero integer. The default value is 5.')
    gSig = _turn_into_array(gSig)

    cnmfe_params.set('preprocess', {'p': 1})

    cnmfe_params.set(
        'init', {
            'K': K,
            'min_corr': min_corr,
            'min_pnr': min_pnr,
            'gSiz': gSiz,
            'gSig': gSig
        })

    if rf is None:
        cnmfe_params.set('patch', {'rf': None, 'stride': 1})
    else:
        cnmfe_params.set('patch', {'rf': np.array(rf), 'stride': stride})

    cnmfe_params.set('data', {'decay_time': decay_time})

    cnmfe_params.set('quality', {'min_SNR': min_SNR, 'rval_thr': rval_thr})

    cnmfe_params.set('merging', {'merge_thr': merge_threshold})

    # set parameters that force CNMF into one-photon mode with no temporal or spatial downsampling,
    # except for the background term
    cnmfe_params.set(
        'init', {
            'center_psf': True,
            'method_init': 'corr_pnr',
            'normalize_init': False,
            'nb': -1,
            'ssub_B': ssub_B,
            'tsub': 1,
            'ssub': 1
        })
    cnmfe_params.set(
        'patch', {
            'only_init': True,
            'low_rank_background': None,
            'nb_patch': -1,
            'p_tsub': 1,
            'p_ssub': 1
        })
    cnmfe_params.set('spatial', {
        'nb': -1,
        'update_background_components': False
    })
    cnmfe_params.set('temporal', {'nb': -1, 'p': 1})

    # construct and run CNMFE
    print('Running CNMFe...')
    cnmfe = CNMF(num_processes, dview=dview, params=cnmfe_params)
    cnmfe.fit(Y)

    # run auto accept/reject
    print('Estimating component quality...')
    idx_components, idx_components_bad, comp_SNR, r_values, pred_CNN = estimate_components_quality_auto(
        Y,
        cnmfe.estimates.A,
        cnmfe.estimates.C,
        cnmfe.estimates.b,
        cnmfe.estimates.f,
        cnmfe.estimates.YrA,
        frame_rate,
        cnmfe_params.get('data', 'decay_time'),
        cnmfe_params.get('init', 'gSiz'),
        cnmfe.dims,
        dview=None,
        min_SNR=cnmfe_params.get('quality', 'min_SNR'),
        use_cnn=False)

    save_cnmfe(cnmfe, output_file, good_idx=idx_components)
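A parameter file covering every name in expected_params might be generated like this (all values are illustrative placeholders, not tuning recommendations):

import yaml

params = {
    'gSiz': 13, 'gSig': 5, 'K': 10,
    'min_corr': 0.8, 'min_pnr': 10,
    'rf': [40, 40], 'stride': 20,
    'decay_time': 0.4, 'min_SNR': 3, 'rval_thr': 0.85,
    'merge_threshold': 0.8, 'ssub_B': 2,
    'frame_rate': 20, 'num_rows': 512, 'num_cols': 512,
    'num_frames': 1000, 'num_processes': 4,
}

with open('cnmfe_params.yaml', 'w') as f:
    yaml.safe_dump(params, f)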
Example #9
def OnACID_A_init(fr, fnames, out, hfile, epochs=2):

    # %% set up some parameters

    decay_time = .4  # approximate length of transient event in seconds
    gSig = (4, 4)  # expected half size of neurons
    p = 1  # order of AR indicator dynamics
    thresh_CNN_noisy = 0.8  # CNN threshold for candidate components
    gnb = 2  # number of background components
    init_method = 'cnmf'  # initialization method
    min_SNR = 2.5  # signal to noise ratio for accepting a component
    rval_thr = 0.8  # space correlation threshold for accepting a component
    ds_factor = 1  # spatial downsampling factor: new_img = img / ds_factor (increases speed but may lose some fine structure)

    # K = 25  # number of components per patch
    patch_size = 32  # size of patch
    stride = 3  # amount of overlap between patches

    max_num_added = 5
    max_comp_update_shape = np.inf
    update_num_comps = False

    # recompute gSig if downsampling is involved
    gSig = tuple(np.ceil(np.array(gSig) / ds_factor).astype('int'))
    mot_corr = True  # flag for online motion correction
    pw_rigid = False  # flag for pw-rigid motion correction (slower but potentially more accurate)
    # maximum allowed shift during motion correction
    max_shifts_online = np.ceil(10. / ds_factor).astype('int')
    sniper_mode = False  # use a CNN to detect new neurons (otherwise space correlation)
    # set up some additional supporting parameters needed for the algorithm
    # (these are default values but can change depending on dataset properties)
    init_batch = 500  # number of frames for initialization (presumably from the first file)
    K = 2  # initial number of components
    show_movie = False  # show the movie as the data gets processed
    print("Frame rate: {}".format(fr))
    params_dict = {
        'fr': fr,
        'fnames': fnames,
        'decay_time': decay_time,
        'gSig': gSig,
        'gnb': gnb,
        'p': p,
        'min_SNR': min_SNR,
        'rval_thr': rval_thr,
        'ds_factor': ds_factor,
        'nb': gnb,
        'motion_correct': mot_corr,
        'normalize': True,
        'sniper_mode': sniper_mode,
        'K': K,
        'use_cnn': False,
        'epochs': epochs,
        'max_shifts_online': max_shifts_online,
        'pw_rigid': pw_rigid,
        'min_num_trial': 10,
        'show_movie': show_movie,
        'save_online_movie': False,
        "max_num_added": max_num_added,
        "max_comp_update_shape": max_comp_update_shape,
        "update_num_comps": update_num_comps,
        "dist_shape_update": update_num_comps,
        'init_batch': init_batch,
        'init_method': init_method,
        'rf': patch_size // 2,
        'stride': stride,
        'thresh_CNN_noisy': thresh_CNN_noisy
    }
    opts = CNMFParams(params_dict=params_dict)
    with h5py.File(hfile, 'r') as hf:
        ests = Estimates(A=load_A(hf))
    cnm = online_cnmf.OnACID(params=opts, estimates=ests)
    cnm.estimates = ests
    cnm.fit_online()
    cnm.save(out)
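A hypothetical invocation (all file names are placeholders; hfile must contain a spatial footprint matrix that load_A can read):

OnACID_A_init(
    fr=15.0,                      # imaging frame rate in Hz
    fnames=['session.tif'],       # placeholder movie path
    out='onacid_results.hdf5',    # where cnm.save() writes
    hfile='seed_footprints.h5',   # HDF5 with the initial A matrix
    epochs=2,
)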
Example #10
def run_single(batch_dir, UUID, output):
    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)
    file_path = os.path.join(batch_dir, UUID)

    filename = [file_path + '_input.tiff']
    input_params = pickle.load(open(file_path + '.params', 'rb'))

    print('*********** Creating Process Pool ***********')
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=n_processes, single_thread=False, ignore_preexisting=True
    )

    memmap_fname = cm.save_memmap(
        filename,
        base_name=f'memmap-{UUID}',
        order='C',
        border_to_0=input_params['border_pix'],
        dview=dview
    )

    Yr, dims, T = cm.load_memmap(memmap_fname)
    Y = np.reshape(Yr.T, [T] + list(dims), order='F')

    Ain = None

    # seed components
    if 'use_seeds' in input_params:
        if input_params['use_seeds']:
            try:
                # see if it's an h5 file produced by the nuset_segment GUI
                hdict = HdfTools.load_dict(f'{file_path}.ain', 'data')
                Ain = hdict['sparse_mask']
            except Exception:
                try:
                    Ain = np.load(f'{file_path}.ain')
                except Exception as e:
                    output['warnings'] = f'Could not seed components, make sure that ' \
                        f'the .ain file exists in the batch dir: {e}'

    # seeded
    if Ain is not None:
        input_params['cnmf_kwargs'].update(
            {
                'only_init': False,
                'rf': None
            }
        )

    cnmf_params = CNMFParams(params_dict=input_params['cnmf_kwargs'])

    cnm = cnmf.CNMF(
        dview=dview,
        n_processes=n_processes,
        Ain=Ain,
        params=cnmf_params,
    )

    cnm.fit(Y)

    if input_params['refit']:
        cnm = cnm.refit(Y, dview=dview)

    cnm.params.change_params(params_dict=input_params['eval_kwargs'])

    cnm.estimates.evaluate_components(
        Y,
        cnm.params,
        dview=dview
    )

    cnm.estimates.select_components(use_object=True)

    out_filename = f'{UUID}_results.hdf5'
    cnm.save(out_filename)

    output_files = [out_filename]

    output.update(
        {
            'output': UUID,
            'status': 1,
            'output_files': output_files
        }
    )

    dview.terminate()

    return output
Example #11
def run_multi(batch_dir, UUID, output):
    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)
    file_path = os.path.join(batch_dir, UUID)

    filename = [file_path + '_input.tiff']
    input_params = pickle.load(open(file_path + '.params', 'rb'))

    seq = tifffile.TiffFile(filename[0]).asarray()
    seq_shape = seq.shape

    # assume default tzxy
    for z in range(seq.shape[1]):
        tifffile.imsave(f'{file_path}_z{z}.tiff', seq[:, z, :, :])

    del seq

    print('*********** Creating Process Pool ***********')
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=n_processes, single_thread=False, ignore_preexisting=True
    )
    num_components = 0
    output_files = []
    for z in range(seq_shape[1]):
        print(f"Plane {z} / {seq_shape[1]}")
        filename = [f'{file_path}_z{z}.tiff']
        print('Creating memmap')

        memmap_fname = cm.save_memmap(
            filename,
            base_name=f'memmap-{UUID}',
            order='C',
            border_to_0=input_params['border_pix'],
            dview=dview
        )

        Yr, dims, T = cm.load_memmap(memmap_fname)
        Y = np.reshape(Yr.T, [T] + list(dims), order='F')

        Ain = None

        # seed components
        # see if it's an h5 file produced by the nuset_segment GUI
        try:
            hdict = HdfTools.load_dict(f'{file_path}.ain', 'data')
            Ain = hdict['sparse_mask'][str(z)]
        except Exception as e:
            output['warnings'] = f'Could not seed components, make sure that ' \
                f'the .ain file exists in the batch dir: {e}'

        # seeded
        if Ain is not None:
            input_params['cnmf_kwargs'].update(
                {
                    'only_init': False,
                    'rf': None
                }
            )

        cnmf_params = CNMFParams(params_dict=input_params['cnmf_kwargs'])

        cnm = cnmf.CNMF(
            dview=dview,
            n_processes=n_processes,
            Ain=Ain,
            params=cnmf_params,
        )

        cnm.fit(Y)

        if input_params['refit']:
            cnm = cnm.refit(Y, dview=dview)

        cnm.params.set('quality', input_params['eval_kwargs'])

        cnm.estimates.evaluate_components(
            Y,
            cnm.params,
            dview=dview
        )

        cnm.estimates.select_components(use_object=True)

        num_components += len(cnm.estimates.C)

        out_filename = f'{UUID}_results_z{z}.hdf5'
        cnm.save(out_filename)

        output_files.append(out_filename)

        os.remove(filename[0])

    output.update(
        {
            'output': UUID,
            'status': 1,
            'output_files': output_files,
            'num_components': num_components
        }
    )

    dview.terminate()

    return output
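A sketch of how a dispatcher might choose between the two entry points based on the input movie's dimensionality (the function is hypothetical; the tzxy assumption mirrors run_multi above):

import os
import tifffile

def run(batch_dir, UUID):
    output = {'status': 0, 'output_info': ''}
    movie = tifffile.TiffFile(
        os.path.join(batch_dir, UUID + '_input.tiff')).asarray()
    if movie.ndim == 4:  # tzxy: one CNMF run per z-plane
        return run_multi(batch_dir, UUID, output)
    return run_single(batch_dir, UUID, output)  # txy: single plane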