def test_clusters_metrics():
    """Check quick_unit_metrics output on simulated spike trains, including missing clusters.

    NOTE(review): another function with this exact name appears later in this module; pytest
    only collects the last definition — confirm which one should survive and rename the other.
    """
    np.random.seed(54)  # fixed seed so the stochastic spike trains are reproducible
    rec_length = 1000  # recording duration in seconds
    frs = np.array([3, 100, 80, 40])  # firing rates
    # cluster id 2 is deliberately absent: here we make sure one of the clusters has no spike
    cid = [0, 1, 3, 4]
    t, a, c = multiple_spike_trains(firing_rates=frs, rec_len_secs=rec_length, cluster_ids=cid)
    # sinusoidal depth shift where the cluster id drives the frequency of the oscillation
    d = np.sin(2 * np.pi * c / rec_length * t) * 100

    def _assertions(dfm, idf, target_cid):
        # dfm: qc dataframe, idf: indices of existing clusters in dfm,
        # target_cid: expected full set of cluster ids in dfm
        # (closes over `cid` and `frs`: only the simulated clusters carry spikes)
        assert np.allclose(dfm['amp_median'][idf] / np.exp(5.5) * 1e6, 1, rtol=1.1)
        assert np.allclose(dfm['amp_std_dB'][idf] / 20 * np.log10(np.exp(0.5)), 1, rtol=1.1)
        assert np.allclose(dfm['drift'][idf], np.array(cid) * 100 * 4 * 3.6, rtol=1.1)
        assert np.allclose(dfm['firing_rate'][idf], frs, rtol=1.1)
        assert np.allclose(dfm['cluster_id'], target_cid)

    # # check with straight indexing
    # dfm = quick_unit_metrics(c, t, a, d)
    # _assertions(dfm, np.arange(4), cid)

    # check with missing clusters: request ids 0..4 even though id 2 has no spikes
    dfm = quick_unit_metrics(c, t, a, d, cluster_ids=np.arange(5))
    idf, _ = ismember(np.arange(5), cid)
    _assertions(dfm, idf, np.arange(5))
def unit_metrics_ks2(ks2_path=None, m=None, save=True):
    """
    Given a path containing kilosort 2 output, compute quality metrics and optionally save them
    to a cluster_metrics.csv file.

    :param ks2_path: path (pathlib.Path) to a ks2 output directory; may be None if `m` is given,
        in which case the ks2 TSV side-files are skipped and nothing is saved
    :param m: phylib `TemplateModel` object; built from `ks2_path` when None
    :param save: bool, write cluster_metrics.csv into `ks2_path` (requires `ks2_path`)
    :return: pandas DataFrame of quality metrics, one row per cluster
    """
    # ensure that either a ks2_path or a phylib `TemplateModel` object with unit info is given
    assert not (ks2_path is None and m is None), 'Must either specify a path to a ks2 output ' \
                                                'directory, or a phylib `TemplateModel` object'
    # create phylib `TemplateModel` if not given
    # bugfix: condition was `if None` (always falsy), so the model was never built from ks2_path
    m = phy_model_from_ks2_path(ks2_path) if m is None else m
    # compute metrics and convert to `DataFrame`
    r = pd.DataFrame(
        quick_unit_metrics(m.spike_clusters, m.spike_times, m.amplitudes, m.depths))
    # TODO compute drift as a function of time here
    # TODO compute metrics using sample waveforms here
    # compute labels based on metrics
    df_labels = pd.DataFrame(
        unit_labels(m.spike_clusters, m.spike_times, m.amplitudes))
    r = r.set_index('cluster_id', drop=False).join(df_labels.set_index('cluster_id'))
    # the ks2 side-files and the csv output only exist relative to a ks2 directory
    if ks2_path is not None:
        # include the ks2 cluster contamination if `cluster_ContamPct` file exists
        file_contamination = ks2_path.joinpath('cluster_ContamPct.tsv')
        if file_contamination.exists():
            contam = pd.read_csv(file_contamination, sep='\t')
            contam.rename(columns={'ContamPct': 'ks2_contamination_pct'}, inplace=True)
            r = r.set_index('cluster_id', drop=False).join(contam.set_index('cluster_id'))
        # include the ks2 cluster labels if `cluster_KSLabel` file exists
        file_labels = ks2_path.joinpath('cluster_KSLabel.tsv')
        if file_labels.exists():
            ks2_labels = pd.read_csv(file_labels, sep='\t')
            ks2_labels.rename(columns={'KSLabel': 'ks2_label'}, inplace=True)
            r = r.set_index('cluster_id', drop=False).join(ks2_labels.set_index('cluster_id'))
        if save:
            # the file name contains the label of the probe (directory name in this case)
            r.to_csv(ks2_path.joinpath('cluster_metrics.csv'))
    return r
def test_clusters_metrics():
    """Sanity-check quick_unit_metrics on simulated spike trains (all clusters firing).

    NOTE(review): this redefines a `test_clusters_metrics` declared earlier in the module, so
    only this definition is collected by pytest — confirm and rename one of the two.
    """
    frs = [3, 200, 259, 567]  # firing rates
    t, a, c = multiple_spike_trains(firing_rates=frs, rec_len_secs=1000, cluster_ids=[0, 1, 3, 4])
    # sinusoidal shift where cluster id drives period
    d = np.sin(2 * np.pi * c / 1000 * t) * 100
    dfm = quick_unit_metrics(c, t, a, d)
    assert np.allclose(dfm['amp_median'] / np.exp(5.5) * 1e6, 1, rtol=1.1)
    assert np.allclose(dfm['amp_std_dB'] / 20 * np.log10(np.exp(0.5)), 1, rtol=1.1)
    assert np.allclose(dfm['drift'], np.array([0, 1, 3, 4]) * 100 * 4 * 3.6, rtol=1.1)
    # bugfix: np.allclose's result was previously discarded, so the firing-rate check never
    # asserted anything; use the same wide rtol as the other checks since rates are stochastic
    assert np.allclose(dfm['firing_rate'], frs, rtol=1.1)