Exemplo n.º 1
0
                          sorting,
                          folder,
                          load_if_exists=True,
                          ms_before=1,
                          ms_after=2.,
                          max_spikes_per_unit=500,
                          n_jobs=1,
                          chunk_size=30000)

##############################################################################
# plot_unit_waveforms()
# ~~~~~~~~~~~~~~~~~~~~~
#
# Waveforms of the first four units only.

unit_ids = sorting.unit_ids[:4]
sw.plot_unit_waveforms(we, unit_ids=unit_ids)

##############################################################################
# plot_unit_templates()
# ~~~~~~~~~~~~~~~~~~~~~
#
# Templates of every unit, drawn with ``ncols=5``.

unit_ids = sorting.unit_ids
sw.plot_unit_templates(we, ncols=5, unit_ids=unit_ids)

##############################################################################
# plot_unit_probe_map()
# ~~~~~~~~~~~~~~~~~~~~~
#
# Probe map restricted to the first four units.

unit_ids = sorting.unit_ids[:4]
sw.plot_unit_probe_map(we, unit_ids=unit_ids)
Exemplo n.º 2
0
# Attach the reference snippets belonging to cluster 1 to unit 1 as a spike
# feature named 'unitId1', then list the feature names stored for that unit.
ID1_features = A_snippets_reference[cluster == 1, :]
sortingPipeline.set_unit_spike_features(
    unit_id=1, feature_name='unitId1', value=ID1_features)
print(f"Spike feature names: "
      f"{sortingPipeline.get_unit_spike_feature_names(unit_id=1)!s}")

# Compare the pipeline's sorting against the ground-truth sorting and plot
# the (ordered) agreement matrix.
cmp_gt_SP = sc.compare_sorter_to_ground_truth(
    gtOutput, sortingPipeline, exhaustive_gt=True)
sw.plot_agreement_matrix(cmp_gt_SP, ordered=True)

# Additional comparison metrics.
perf = cmp_gt_SP.get_performance()
print(perf)
# TODO: also report well-detected units, SNR and firing rates (e.g. via
# cmp_gt_SP.get_well_detected_units / st.validation.compute_firing_rates).

# Raster plot and unit waveforms of the pipeline's sorting.
# NOTE(review): the ``_gt`` suffix suggests ground truth, but both widgets
# below plot ``sortingPipeline`` -- confirm which sorting was intended.
w_rs_gt = sw.plot_rasters(sortingPipeline, sampling_frequency=sampleRate)

w_wf_gt = sw.plot_unit_waveforms(
    recordingInput, sortingPipeline, max_spikes_per_unit=100)
Exemplo n.º 3
0
 def test_unitwaveforms(self):
     """Smoke-test plot_unit_waveforms with default and restricted arguments."""
     # Default call over the whole waveform object; the returned widget is
     # not inspected, so it is not bound to a name.
     sw.plot_unit_waveforms(self._we)
     # The keyword variants below are only exercised on the first six units.
     unit_ids = self._sorting.unit_ids[:6]
     # Presumably two alternative channel-selection modes -- TODO confirm.
     sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids)
     sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids)
Exemplo n.º 4
0
 def test_unitwaveforms(self):
     """Plot waveforms with default, max_channels=5 and radius_um=60 settings."""
     # Each entry extends the same base call; calls are issued in this order.
     for extra_kwargs in ({}, {'max_channels': 5}, {'radius_um': 60}):
         sw.plot_unit_waveforms(self._we, **extra_kwargs)
Exemplo n.º 5
0
 def test_unitwaveforms(self):
     """Smoke test: plotting waveforms with default arguments must not raise."""
     we = self._we
     sw.plot_unit_waveforms(we)
Exemplo n.º 6
0
    11, 12, 14, 147, 239, 119, 241, 120, 245, 124, 246, 127, 249, 250, 1, 2,
    133, 135, 5, 6, 140, 10, 141, 143, 13, 144, 17, 238, 117
]) + 1

# Load the MCS HDF5 recording and attach the OrgMEA_252 probe layout.
recordings_folder = Path.cwd().parent.joinpath('recordings')
recording_file = recordings_folder.joinpath(
    "2019-11-12T15-38-27McsRecording.h5")
recording = se.MCSH5RecordingExtractor(recording_file)
recording_prb = recording.load_probe_file(probe_file='OrgMEA_252.prb')

# Remove the ground electrodes. The listed ids are shifted by +1, presumably
# to match the probe file's channel numbering -- TODO confirm.
recording_gnd_rem = st.preprocessing.remove_bad_channels(
    recording_prb,
    bad_channel_ids=np.array([226, 95, 196, 62, 29, 166, 136, 128]) + 1,
)

# Band-pass filter the signal (freq_min=200, freq_max=6000, presumably Hz).
recording_f = st.preprocessing.bandpass_filter(
    recording_gnd_rem, freq_min=200, freq_max=6000)

# TODO: remove additional bad channels after filtering once their ids are
# known (the earlier draft still had placeholder ids).

# Common median reference.
# NOTE(review): recording_cmr is not used below -- the waveform plot uses
# recording_f. Confirm this is intended.
recording_cmr = st.preprocessing.common_reference(
    recording_f, reference='median')

# Load a previously saved sorting. Despite its name, save_folder points at
# a single .npz file.
save_folder = Path.cwd().parent.joinpath('saved__sort_test.npz')
sorting = se.NpzSortingExtractor(save_folder)

w_wf2 = sw.plot_unit_waveforms(
    recording_f, sorting, max_spikes_per_unit=100)
Here is a gallery of all the available widgets, each demonstrated on a RecordingExtractor/SortingExtractor pair.
'''

import spikeinterface.extractors as se
import spikeinterface.widgets as sw

##############################################################################
# First, let's create a toy example with the `extractors` module
# (``duration=10``, ``num_channels=4``, fixed ``seed=0`` so the figures are
# reproducible):

recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

##############################################################################
# plot_unit_waveforms()
# ~~~~~~~~~~~~~~~~~~~~~

w_wf = sw.plot_unit_waveforms(recording, sorting, max_num_waveforms=100)

##############################################################################
# plot_amplitudes_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampd = sw.plot_amplitudes_distribution(recording, sorting, max_num_waveforms=300)

##############################################################################
# plot_amplitudes_timeseries()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampt = sw.plot_amplitudes_timeseries(recording, sorting, max_num_waveforms=300)

##############################################################################
# plot_features()
def _save_current_figure(path):
    """Save the active matplotlib figure to *path* at 200 dpi, then close it."""
    plt.savefig(path, dpi=200)
    # Close it so repeated calls do not accumulate open figures.
    plt.close()


def get_sort_info(sorting, recording, out_loc, sampling_frequency=None):
    """Print a summary of *sorting* and save diagnostic plots to *out_loc*.

    PNG files written (200 dpi) into ``out_loc``: ``raster_10s.png``,
    ``isi.png``, ``pca.png``, ``waveforms_<N>.png`` (N = up to 20 units) and
    ``chan0_forms.png`` (per-channel waveforms of the first entry returned by
    ``st.postprocessing.get_unit_waveforms``).

    Args:
        sorting: sorting extractor to summarise.
        recording: recording extractor the sorting belongs to; used for PCA
            features and waveform extraction.
        out_loc: directory the images are written to. Nothing here creates
            it, so it should already exist.
        sampling_frequency: samples per second, used to print the first
            unit's spike train in seconds. Defaults to
            ``recording.get_sampling_frequency()``.

    Returns early, after printing the unit count, when the sorting has no
    units.
    """
    unit_ids = sorting.get_unit_ids()
    print("Found", len(unit_ids), 'units')
    print('Unit ids:', unit_ids)
    # len() rather than truthiness: unit_ids may be a numpy array.
    if len(unit_ids) == 0:
        # Every step below indexes the first unit; nothing to plot.
        return

    if sampling_frequency is None:
        sampling_frequency = recording.get_sampling_frequency()

    spike_train = sorting.get_unit_spike_train(unit_id=unit_ids[0])
    print('Spike train of first unit:',
          np.asarray(spike_train) / sampling_frequency)

    # Spike raster plot over the first t_len seconds.
    t_len = 10
    o_loc = os.path.join(out_loc, "raster_" + str(t_len) + "s.png")
    print("Saving {}s rasters to {}".format(t_len, o_loc))
    sw.plot_rasters(sorting, trange=[0, t_len])
    _save_current_figure(o_loc)

    sw.plot_isi_distribution(sorting, bins=10, window=1)
    o_loc = os.path.join(out_loc, "isi.png")
    print("Saving isi to {}".format(o_loc))
    _save_current_figure(o_loc)

    # Cross-correlograms (sw.plot_crosscorrelograms) are intentionally skipped.
    sw.plot_pca_features(recording,
                         sorting,
                         colormap='rainbow',
                         nproj=3,
                         max_spikes_per_unit=100)
    o_loc = os.path.join(out_loc, "pca.png")
    print("Plotting pca to {}".format(o_loc))
    _save_current_figure(o_loc)

    # See also spiketoolkit.postprocessing.get_unit_waveforms
    num_samps = min(20, len(unit_ids))
    sw.plot_unit_waveforms(sorting=sorting,
                           recording=recording,
                           unit_ids=unit_ids[:num_samps],
                           max_spikes_per_unit=20)
    o_loc = os.path.join(out_loc, "waveforms_" + str(num_samps) + ".png")
    print("Saving {} waveforms to {}".format(num_samps, o_loc))
    _save_current_figure(o_loc)

    wf_by_group = st.postprocessing.get_unit_waveforms(
        recording,
        sorting,
        ms_before=1,
        ms_after=2,
        save_as_features=False,
        verbose=True,
        grouping_property="group",
        compute_property_from_recording=True)
    o_loc = os.path.join(out_loc, "chan0_forms.png")
    fig, ax = plt.subplots()
    # NOTE(review): presumably one array per unit shaped
    # (n_spikes, n_channels, n_samples) -- TODO confirm. Only the first
    # entry is drawn, one colour per channel.
    wf = wf_by_group[0]
    colors = ["k", "r", "b", "g"]
    for i in range(wf.shape[1]):
        # Cycle colours so groups with more than four channels do not raise.
        ax.plot(wf[:, i, :].T, color=colors[i % len(colors)], lw=0.3)
    print("Saving first waveform on the first tetrode to {}".format(o_loc))
    fig.savefig(o_loc, dpi=200)
    plt.close(fig)