Example #1
0
def test_clusters_metrics():
    """Check `quick_unit_metrics` on simulated spike trains where one requested cluster is empty.

    Spike trains are simulated for clusters 0, 1, 3 and 4 only, and metrics are then requested
    for clusters 0..4, so cluster 2 has no spikes. Metric values are compared only on the rows
    of the clusters that were simulated. The ``cluster_id`` column must list all five ids.
    """
    np.random.seed(54)
    duration = 1000
    rates = np.array([3, 100, 80, 40])  # one firing rate per simulated cluster
    spiking_ids = [0, 1, 3, 4]  # cluster 2 is left out on purpose so it has no spikes
    times, amps, clusters = multiple_spike_trains(firing_rates=rates,
                                                  rec_len_secs=duration,
                                                  cluster_ids=spiking_ids)
    # sinusoidal depth shift of amplitude 100 whose frequency scales with the cluster id
    depths = np.sin(2 * np.pi * clusters / duration * times) * 100

    metrics = quick_unit_metrics(clusters, times, amps, depths, cluster_ids=np.arange(5))
    # presumably a boolean mask over the 5 requested ids, True where the cluster was simulated
    has_spikes, _ = ismember(np.arange(5), spiking_ids)

    # NOTE: rtol=1.1 makes the value checks loose, order-of-magnitude sanity checks
    assert np.allclose(metrics['amp_median'][has_spikes] / np.exp(5.5) * 1e6, 1, rtol=1.1)
    assert np.allclose(metrics['amp_std_dB'][has_spikes] / 20 * np.log10(np.exp(0.5)),
                       1,
                       rtol=1.1)
    expected_drift = np.array(spiking_ids) * 100 * 4 * 3.6
    assert np.allclose(metrics['drift'][has_spikes], expected_drift, rtol=1.1)
    assert np.allclose(metrics['firing_rate'][has_spikes], rates, rtol=1.1)
    # every requested id, including the empty cluster 2, must appear in the output
    assert np.allclose(metrics['cluster_id'], np.arange(5))
Example #2
0
def _join_ks2_cluster_tsv(r, file_tsv, columns):
    """Left-join the columns of a ks2 per-cluster ``.tsv`` file onto ``r`` if the file exists.

    :param r: metrics `DataFrame` that has a ``cluster_id`` column
    :param file_tsv: `pathlib.Path` of a tab-separated file with a ``cluster_id`` column
    :param columns: mapping used to rename the file's columns before joining
    :return: ``r`` indexed by ``cluster_id`` (column kept) with the file's columns joined on,
        or ``r`` unchanged when the file does not exist
    """
    if not file_tsv.exists():
        return r
    df = pd.read_csv(file_tsv, sep='\t').rename(columns=columns)
    return r.set_index('cluster_id', drop=False).join(df.set_index('cluster_id'))


def unit_metrics_ks2(ks2_path=None, m=None, save=True):
    """
    Compute per-unit quality metrics from kilosort 2 output and optionally save them
    to a ``cluster_metrics.csv`` file in ``ks2_path``.

    :param ks2_path: path (str or `pathlib.Path`) to a kilosort 2 output directory. It is used
        to build the phylib model when ``m`` is not given, to read the optional
        ``cluster_ContamPct.tsv`` / ``cluster_KSLabel.tsv`` files, and as the save location.
    :param m: phylib `TemplateModel` providing ``spike_clusters``, ``spike_times``,
        ``amplitudes`` and ``depths``. It is built from ``ks2_path`` when None.
    :param save: if True, write the metrics to ``ks2_path / 'cluster_metrics.csv'``.
        This requires ``ks2_path``.
    :return: `pandas.DataFrame` of metrics and labels, indexed by ``cluster_id``
    :raises ValueError: if neither ``ks2_path`` nor ``m`` is given, or if ``save`` is True
        and ``ks2_path`` is None
    """
    from pathlib import Path

    # ensure that either a ks2_path or a phylib `TemplateModel` object with unit info is given
    if ks2_path is None and m is None:
        raise ValueError('Must either specify a path to a ks2 output '
                         'directory, or a phylib `TemplateModel` object')
    # fail fast rather than after all metrics have been computed
    if save and ks2_path is None:
        raise ValueError('A ks2_path is required to save the metrics (or pass save=False)')
    # accept str paths too: `.joinpath` below needs a `Path`
    if ks2_path is not None:
        ks2_path = Path(ks2_path)
    # create phylib `TemplateModel` if not given
    if m is None:
        m = phy_model_from_ks2_path(ks2_path)

    # compute metrics and convert to `DataFrame`
    r = pd.DataFrame(
        quick_unit_metrics(m.spike_clusters, m.spike_times, m.amplitudes,
                           m.depths))
    # TODO compute drift as a function of time here
    # TODO compute metrics using sample waveforms here

    # compute labels based on metrics
    df_labels = pd.DataFrame(
        unit_labels(m.spike_clusters, m.spike_times, m.amplitudes))
    r = r.set_index('cluster_id',
                    drop=False).join(df_labels.set_index('cluster_id'))

    # the ks2 per-cluster files only exist when a ks2 output directory is known
    if ks2_path is not None:
        #  include the ks2 cluster contamination if `cluster_ContamPct` file exists
        r = _join_ks2_cluster_tsv(r, ks2_path.joinpath('cluster_ContamPct.tsv'),
                                  {'ContamPct': 'ks2_contamination_pct'})
        #  include the ks2 cluster labels if `cluster_KSLabel` file exists
        r = _join_ks2_cluster_tsv(r, ks2_path.joinpath('cluster_KSLabel.tsv'),
                                  {'KSLabel': 'ks2_label'})

    if save:
        #  fixed file name: the probe is identified by the containing ks2 directory
        r.to_csv(ks2_path.joinpath('cluster_metrics.csv'))

    return r
def test_clusters_metrics():
    """Check `quick_unit_metrics` on simulated spike trains for clusters 0, 1, 3 and 4.

    Median amplitude, amplitude spread, drift and firing rate are compared against the values
    expected from the simulation.
    """
    # presumably `multiple_spike_trains` draws from NumPy's global RNG; seed it so the
    # simulated data, and therefore this test, are reproducible
    np.random.seed(54)
    frs = [3, 200, 259, 567]  # firing rates
    t, a, c = multiple_spike_trains(firing_rates=frs,
                                    rec_len_secs=1000,
                                    cluster_ids=[0, 1, 3, 4])
    # sinusoidal depth shift (amplitude 100) whose frequency scales with the cluster id
    d = np.sin(2 * np.pi * c / 1000 * t) * 100
    dfm = quick_unit_metrics(c, t, a, d)

    # NOTE: rtol=1.1 makes these loose, order-of-magnitude sanity checks
    assert np.allclose(dfm['amp_median'] / np.exp(5.5) * 1e6, 1, rtol=1.1)
    assert np.allclose(dfm['amp_std_dB'] / 20 * np.log10(np.exp(0.5)),
                       1,
                       rtol=1.1)
    assert np.allclose(dfm['drift'],
                       np.array([0, 1, 3, 4]) * 100 * 4 * 3.6,
                       rtol=1.1)
    # the comparison must be asserted: a bare `np.allclose(...)` call checks nothing
    assert np.allclose(dfm['firing_rate'], frs, rtol=1.1)