def recalc_metrics(mountoutput_dir, output_dir, raw_data_dir='',
                   firings_in='firings_processed.mda', metrics_to_update='',
                   mv2_file='', firing_rate_thresh=0.01, isolation_thresh=0.95,
                   noise_overlap_thresh=0.03, peak_snr_thresh=1.5,
                   manual_only=True):
    '''Recalculate metrics and update tags after merging/annealing/curation.

    Updates both threshold-based tags and any manually added tags (stored in
    the mv2 file, which is optional to provide).

    Parameters
    ----------
    mountoutput_dir : str, MountainSort output folder.
    output_dir : str, folder where the output metrics file will be written.
    raw_data_dir : str, optional, usually the tmp folder.
    firings_in : str, optional, which firings mda file to use. Default is
        'firings_processed.mda'.
    metrics_to_update : str, name of the output metrics file.
    mv2_file : str, optional. Path to mv2 file. If provided, manual curation tags will be copied.
    firing_rate_thresh : float, optional, default is 0.01
        Clusters with a firing rate below this threshold are excluded
        (spikes/s).
    isolation_thresh : float, optional, default is 0.95
        Distance to a cluster of noise.
    noise_overlap_thresh : float, optional, default is 0.03
        Fraction of 'noise events' in a cluster.
    peak_snr_thresh : float, optional, default is 1.5
    manual_only : bool, optional
        If True, hard thresholds are not applied; tags are simply copied from
        the mv2 file, if one is supplied. If False, the threshold values
        above are used to assign tags.

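    Examples
    --------
    A minimal usage sketch; all paths and file names below are hypothetical:

        recalc_metrics('/data/ms4/nt1', '/data/ms4/nt1',
                       firings_in='firings_processed.mda',
                       metrics_to_update='metrics_processed.json',
                       mv2_file='/data/ms4/nt1/curation.mv2',
                       manual_only=True)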
    '''
    ds_params = read_dataset_params(mountoutput_dir)
    with open(os.path.join(mountoutput_dir, 'pre.mda.prv'), "r") as f:
        prv = json.load(f)
    p = pathlib.Path(prv['original_path'])

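    # For example (hypothetical path): if prv['original_path'] is
    # '/data/tmp/nt1/pre.mda', then p.parts[-1] is 'pre.mda', so a non-empty
    # raw_data_dir='/other/tmp' resolves to '/other/tmp/pre.mda'.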
    if len(raw_data_dir) == 0:
        timeseries_in = prv['original_path']
    else:
        timeseries_in = os.path.join(raw_data_dir, p.parts[-1])

    logger.info(f'output_dir: {output_dir}')
    compute_cluster_metrics(
        timeseries=timeseries_in,
        firings=os.path.join(mountoutput_dir, firings_in),
        metrics_out=os.path.join(output_dir, metrics_to_update),
        samplerate=ds_params['samplerate'])

    add_curation_tags(output_dir,
                      output_dir,
                      firing_rate_thresh=firing_rate_thresh,
                      isolation_thresh=isolation_thresh,
                      noise_overlap_thresh=noise_overlap_thresh,
                      peak_snr_thresh=peak_snr_thresh,
                      metrics_input=metrics_to_update,
                      metrics_output=metrics_to_update,
                      mv2file=mv2_file,
                      manual_only=manual_only)


def ms4_sort_full(dataset_dir,
                  output_dir,
                  geom=None,
                  adjacency_radius=-1,
                  detect_threshold=3,
                  detect_interval=10,
                  detect_sign=False,
                  num_workers=2,
                  opts=None):
    '''Sort the entire file as one mda.

    Parameters
    ----------
    dataset_dir : str
    output_dir : str
    geom : None or list, optional
    adjacency_radius : float, optional
    detect_threshold : float, optional
    detect_interval : int, optional
    detect_sign : bool, optional
    num_workers : int, optional
    opts : dict or None, optional

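    Examples
    --------
    A hedged sketch; paths are hypothetical, and output_dir is assumed to
    already contain a pre.mda.prv produced by preprocessing:

        ms4_sort_full('/data/raw/nt1', '/data/ms4/nt1',
                      detect_threshold=3, detect_sign=False, num_workers=4)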
    '''
    if geom is None:
        geom = []
    if opts is None:
        opts = {}
    # Fetch dataset parameters
    ds_params = read_dataset_params(dataset_dir)

    ms4alg(timeseries=os.path.join(output_dir, 'pre.mda.prv'),
           geom=geom,
           firings_out=os.path.join(output_dir, 'firings_raw.mda'),
           adjacency_radius=adjacency_radius,
           detect_sign=int(detect_sign),
           detect_threshold=detect_threshold,
           detect_interval=detect_interval,
           num_workers=num_workers,
           opts=opts)

    # Compute cluster metrics
    compute_cluster_metrics(timeseries=os.path.join(output_dir, 'pre.mda.prv'),
                            firings=os.path.join(output_dir,
                                                 'firings_raw.mda'),
                            metrics_out=os.path.join(output_dir,
                                                     'metrics_raw.json'),
                            samplerate=ds_params['samplerate'],
                            opts=opts)


def merge_burst_parents(dataset_dir, output_dir):
    '''Merge burst-parent clusters, then recompute cluster metrics.

    Parameters
    ----------
    dataset_dir : str
    output_dir : str

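    Examples
    --------
    A usage sketch; paths are hypothetical, and output_dir is expected to
    already contain firings_raw.mda, metrics_raw.json, and pre.mda.prv:

        merge_burst_parents('/data/raw/nt1', '/data/ms4/nt1')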
    '''
    pyms_merge_burst_parents(
        firings=os.path.join(output_dir, 'firings_raw.mda'),
        metrics=os.path.join(output_dir, 'metrics_raw.json'),
        firings_out=os.path.join(output_dir, 'firings_burst_merged.mda'))

    ds_params = read_dataset_params(dataset_dir)
    # Compute cluster metrics
    compute_cluster_metrics(
        timeseries=os.path.join(output_dir, 'pre.mda.prv'),
        firings=os.path.join(output_dir, 'firings_burst_merged.mda'),
        metrics_out=os.path.join(output_dir, 'metrics_merged.json'),
        samplerate=ds_params['samplerate'])


def ms4_sort_on_segs(dataset_dir,
                     output_dir,
                     geom=None,
                     adjacency_radius=-1,
                     detect_threshold=3.0,
                     detect_interval=10,
                     detect_sign=False,
                     rm_segment_intermediates=True,
                     num_workers=2,
                     opts=None,
                     mda_opts=None):
    '''Sort by time segments, then join any matching clusters.

    Parameters
    ----------
    dataset_dir : str
    output_dir : str
    geom : None or list, optional
    adjacency_radius : float, optional
    detect_threshold : float, optional
    detect_interval : int, optional
    detect_sign : bool, optional
    rm_segment_intermediates : bool, optional
    num_workers : int, optional
    opts : dict or None, optional
    mda_opts : dict or None, optional

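    Examples
    --------
    A hedged end-to-end sketch; paths and mda_opts values are hypothetical,
    and output_dir is assumed to already contain pre.mda.prv:

        ms4_sort_on_segs('/data/raw/nt1', '/data/ms4/nt1',
                         detect_threshold=3.0, num_workers=4,
                         mda_opts={'anim': 'bear', 'date': '20200317',
                                   'ntrode': 1, 'data_location': '/data/raw'})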
    '''
    if geom is None:
        geom = []
    if opts is None:
        opts = {}
    if mda_opts is None:
        mda_opts = {}

    # Fetch dataset parameters
    ds_params = read_dataset_params(dataset_dir)
    has_keys = {'anim', 'date', 'ntrode', 'data_location'}.issubset(mda_opts)

    if has_keys:
        logger.info('Finding list of mda files from mda directories for '
                    f'date:{mda_opts["date"]}, ntrode:{mda_opts["ntrode"]}')
        mda_list = get_mda_list(mda_opts['anim'], mda_opts['date'],
                                mda_opts['ntrode'], mda_opts['data_location'])
        # calculate time_offsets and total_duration
        sample_offsets, total_samples = get_epoch_offsets(
            dataset_dir=dataset_dir, opts={'mda_list': mda_list})

    else:
        # calculate time_offsets and total_duration
        sample_offsets, total_samples = get_epoch_offsets(
            dataset_dir=dataset_dir)

    # break up preprocessed data into segments and sort each
    firings_list = []
    timeseries_list = []
    for segind in range(len(sample_offsets)):
        t1 = math.floor(sample_offsets[segind])
        if segind == len(sample_offsets) - 1:
            t2 = total_samples - 1
        else:
            t2 = math.floor(sample_offsets[segind + 1]) - 1

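        # Convert sample indices to minutes for logging:
        # t_min = t / samplerate / 60. For example, with a (hypothetical)
        # samplerate of 30000 Hz, sample 54_000_000 maps to
        # 54_000_000 / 30000 / 60 = 30.0 minutes.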
        t1_min = t1 / ds_params['samplerate'] / 60
        t2_min = t2 / ds_params['samplerate'] / 60
        logger.info(f'Segment {segind + 1}: t1={t1}, t2={t2}, '
                    f't1_min={t1_min:.3f}, t2_min={t2_min:.3f}')

        pre_outpath = os.path.join(dataset_dir, f'pri-{segind + 1}.mda')
        pyms_extract_segment(timeseries=os.path.join(output_dir,
                                                     'pre.mda.prv'),
                             timeseries_out=pre_outpath,
                             t1=t1,
                             t2=t2,
                             opts=opts)

        firings_outpath = os.path.join(dataset_dir,
                                       f'firings-{segind + 1}.mda')
        ms4alg(timeseries=pre_outpath,
               firings_out=firings_outpath,
               geom=geom,
               detect_sign=int(detect_sign),
               adjacency_radius=adjacency_radius,
               detect_threshold=detect_threshold,
               detect_interval=detect_interval,
               num_workers=num_workers,
               opts=opts)

        firings_list.append(firings_outpath)
        timeseries_list.append(pre_outpath)

    firings_out_final = os.path.join(output_dir, 'firings_raw.mda')
    txt_out = os.path.join(output_dir, 'firings_raw_anneal_log.json')
    # sample_offsets have to be converted into a string to be properly passed
    # into the processor
    str_sample_offsets = ','.join(map(str, sample_offsets))
    logger.info(str_sample_offsets)

    # Drift tracking
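    # (joins matching clusters across segment boundaries; the empty lists
    # below presumably skip the optional distance-matrix outputs)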
    pyms_anneal_segs(timeseries_list=timeseries_list,
                     firings_list=firings_list,
                     firings_out=firings_out_final,
                     text_out=txt_out,
                     dmatrix_out=[],
                     k1_dmatrix_out=[],
                     k2_dmatrix_out=[],
                     dmatrix_templates_out=[],
                     time_offsets=str_sample_offsets)

    # clear the temp pre and firings files if specified
    if rm_segment_intermediates:
        clear_seg_files(timeseries_list=timeseries_list,
                        firings_list=firings_list)

    # Compute cluster metrics
    compute_cluster_metrics(timeseries=os.path.join(output_dir, 'pre.mda.prv'),
                            firings=os.path.join(output_dir,
                                                 'firings_raw.mda'),
                            metrics_out=os.path.join(output_dir,
                                                     'metrics_raw.json'),
                            samplerate=ds_params['samplerate'],
                            opts=opts)

    recalc_metrics_epoch_electrode(
        params=('', output_dir, '', mda_opts),
        rm_segment_intermediates=rm_segment_intermediates,
        updated_mda='firings_raw.mda',
        mv2_file='',
        manual_only=False)