def test_spike_detect_real_data(spikedetekt):
    """Run SpikeDetekt on a real 32-channel recording and check its outputs.

    The default ``spikedetekt`` settings are read from
    ``../default_settings.py``. Detection runs serially on the first 50000
    samples of ``test-32ch-10s.dat``. The concatenated spike samples, masks
    and features of channel group 0 are checked for shape consistency, and
    the first 30000 samples are plotted.

    NOTE(review): the ``spikedetekt`` fixture argument is not used in the body.
    NOTE(review): another function named ``test_spike_detect_real_data`` is
    defined later in this module and replaces this one at import time, so
    pytest never collects this version. Rename or remove one of them.
    """
    with TemporaryDirectory() as tempdir:
        # Load the default parameters from the parent directory of this file.
        curdir = op.dirname(op.realpath(__file__))
        default_settings_path = op.join(curdir, op.pardir,
                                        'default_settings.py')
        settings = _read_python(default_settings_path)

        sample_rate = 20000
        params = settings['spikedetekt']
        params['sample_rate'] = sample_rate

        n_channels = 32
        npc = params['n_features_per_channel']
        # Waveform length, in samples, around each detected spike.
        n_samples_w = params['extract_s_before'] + params['extract_s_after']
        probe = load_probe('1x32_buzsaki')

        # Load the raw int16 traces. Derive the number of samples from the
        # file size rather than hard-coding it, so only the channel count
        # needs to agree with the file.
        path = _download_test_data('test-32ch-10s.dat')
        traces = np.fromfile(path, dtype=np.int16).reshape((-1, n_channels))

        # Run the detection on the first 50000 samples only.
        sd = SpikeDetekt(tempdir=tempdir, probe=probe, **params)
        out = sd.run_serial(traces, interval_samples=(0, 50000))
        n_spikes = out.n_spikes_total

        # Per-group outputs are sequences of arrays (presumably one per
        # chunk -- TODO confirm); stack those of channel group 0.
        spike_samples = np.concatenate(out.spike_samples[0])
        masks = np.concatenate(out.masks[0])
        features = np.concatenate(out.features[0])

        assert spike_samples.shape == (n_spikes,)
        assert masks.shape == (n_spikes, n_channels)
        assert features.shape == (n_spikes, n_channels, npc)

        # There should not be any spike with only masked channels.
        assert np.all(masks.max(axis=1) > 0)

        # Plot the first 30000 samples with the detected spikes overlaid.
        from phy.plot.traces import plot_traces
        c = plot_traces(traces[:30000, :],
                        spike_samples=spike_samples,
                        masks=masks,
                        n_samples_per_spike=n_samples_w,
                        show=False)
        show_test(c)
def test_spike_detect_real_data(tempdir, raw_dataset):
    """Detect spikes on a fixture dataset and check the group-0 outputs.

    SpikeDetekt is configured from the dataset's parameters, probe and
    sample rate, then run serially over all of its samples. The spike
    samples, masks and features of channel group 0 are concatenated. When
    that group reports at least one spike, their shapes are checked against
    the spike count and every spike must keep at least one unmasked
    channel. Finally, the first 30000 samples of the group's channels are
    plotted.
    """
    params = raw_dataset.params
    probe = raw_dataset.probe
    detector = SpikeDetekt(tempdir=tempdir,
                           probe=probe,
                           sample_rate=raw_dataset.sample_rate,
                           **params)

    traces = raw_dataset.traces
    n_samples = raw_dataset.n_samples
    n_features = params['n_features_per_channel']
    # Number of samples extracted around each spike.
    waveform_length = (params['extract_s_before'] +
                       params['extract_s_after'])

    # Detection over the full recording.
    result = detector.run_serial(traces, interval_samples=(0, n_samples))

    # Only channel group 0 is checked below.
    group_channels = probe['channel_groups'][0]['channels']
    n_channels = len(group_channels)

    spike_samples = _concatenate(result.spike_samples[0])
    masks = _concatenate(result.masks[0])
    features = _concatenate(result.features[0])

    n_spikes = result.n_spikes_per_group[0]
    # Shape checks are skipped when the group yielded no spikes.
    if n_spikes:
        assert spike_samples.shape == (n_spikes,)
        assert masks.shape == (n_spikes, n_channels)
        assert features.shape == (n_spikes, n_channels, n_features)
        # Each spike must keep at least one channel with a positive mask.
        assert np.all(masks.max(axis=1) > 0)

    # Visual check: group channels, first 30000 samples, spikes overlaid.
    from phy.plot.traces import plot_traces
    canvas = plot_traces(traces[:30000, group_channels],
                         spike_samples=spike_samples,
                         masks=masks,
                         n_samples_per_spike=waveform_length,
                         show=False)
    show_test(canvas)