recording, sorting_true = se.example_datasets.toy_example(
    duration=10,
    num_channels=4,
    seed=0,
)

##############################################################################
# :code:`recording` is a :code:`RecordingExtractor` object. It gives access to the channel ids, the channel
# locations (if present), the sampling frequency of the recording, and the extracellular traces.
# :code:`sorting_true` is a :code:`SortingExtractor` object, which holds the spike-sorting information:
# unit ids, spike trains, etc. Since the data are simulated, :code:`sorting_true` contains the ground-truth
# spiking activity of each unit.
#
# Let's use the :code:`widgets` module to visualize the traces and the raster plots.

w_ts = sw.plot_timeseries(
    recording,
    trange=[0, 5],
)
w_rs = sw.plot_rasters(
    sorting_true,
    trange=[0, 5],
)

##############################################################################
# This is how you retrieve info from a :code:`RecordingExtractor`...

channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)

##############################################################################
# ...and a :code:`SortingExtractor`
unit_ids = sorting_true.get_unit_ids()
Ejemplo n.º 2
0
 def test_rasters(self):
     """Check that ``sw.plot_rasters`` runs on ``self._sorting`` without raising."""
     sorting = self._sorting
     sw.plot_rasters(sorting)
Ejemplo n.º 3
0
import spikeinterface.extractors as se
import spikeinterface.widgets as sw

##############################################################################
# First, let's create a toy example with the `extractors` module:

recording, sorting = se.toy_example(
    duration=10,
    num_channels=4,
    seed=0,
    num_segments=1,
)

##############################################################################
# plot_rasters()
# ~~~~~~~~~~~~~~~~~

w_rs = sw.plot_rasters(sorting)

##############################################################################
# plot_isi_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~

# TODO(alessio): this example is disabled; re-enable it.
# w_isi = sw.plot_isi_distribution(sorting, bins=10, window=1)

##############################################################################
# plot_autocorrelograms()
# ~~~~~~~~~~~~~~~~~~~~~~~~

# TODO(alessio): this example is disabled; re-enable it.
# w_ach = sw.plot_autocorrelograms(sorting, bin_size=1, window=10, unit_ids=[1, 2, 4, 5, 8, 10, 7])
Ejemplo n.º 4
0
local_path = si.download_dataset(
    remote_path='mearec/mearec_test_10s.h5',
)
recording, sorting_true = se.read_mearec(local_path)
print(recording)
print(sorting_true)

##############################################################################
# :code:`recording` is a :code:`RecordingExtractor` object. It gives access to the channel ids, the channel
# locations (if present), the sampling frequency of the recording, and the extracellular traces.
# :code:`sorting_true` is a :code:`SortingExtractor` object, which holds the spike-sorting information:
# unit ids, spike trains, etc. Since the data are simulated, :code:`sorting_true` contains the ground-truth
# spiking activity of each unit.
#
# Let's use the :code:`widgets` module to visualize the traces and the raster plots.

w_ts = sw.plot_timeseries(
    recording,
    time_range=(0, 5),
)
w_rs = sw.plot_rasters(
    sorting_true,
    time_range=(0, 5),
)

##############################################################################
# This is how you retrieve info from a :code:`RecordingExtractor`...

channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
num_seg = recording.get_num_segments()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)

##############################################################################
# Join the spike times and labels of the three reference sets into a single
# sorting object, then inspect it.

times_reference_joined = np.concatenate(
    (times_reference[0], times_reference2[0], times_reference3[0]))
labels_joined = np.concatenate((labels, labels2, labels3), axis=0)

# print(times_reference_joined.shape)
# print(labels_joined.shape)

# NOTE(review): sortingPipeline is defined outside this fragment; presumably a
# spikeextractors NumpySortingExtractor (set_times_labels /
# set_sampling_frequency) -- confirm.
sortingPipeline.set_times_labels(times=times_reference_joined,
                                 labels=labels_joined)
sortingPipeline.set_sampling_frequency(sampling_frequency=Fs)
print('Unit ids = {}'.format(sortingPipeline.get_unit_ids()))
st = sortingPipeline.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(st)))
# Restrict to frames 0..Fs (the first second) so the count matches the
# message below; without bounds this was just a copy of `st`.
st1 = sortingPipeline.get_unit_spike_train(unit_id=1,
                                           start_frame=0,
                                           end_frame=int(Fs))
print('Num. events for first second of unit 1 = {}'.format(len(st1)))
w_rs_gt = sw.plot_rasters(sortingPipeline, sampling_frequency=Fs)
plt.show()

# Load the recording to run other algorithms (MountainSort4 and SpyKING CIRCUS)

import json
import os

script_dir = os.path.dirname(__file__)
# NOTE(review): os.path.join discards every component before an absolute
# path, so script_dir has no effect here and the hard-coded absolute path is
# used as-is.
file_path = os.path.join(
    script_dir,
    '/home/maitreyi/spikeforest_recordings/recordings/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10_C4/synth_mearec_tetrode_noise10_K10_C4.json'
)
# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(file_path, 'r', encoding='utf-8') as fi:
    spec = json.load(fi)

# Step 2 - parse the JSON spec for the input data
name = spec['name']
Ejemplo n.º 6
0
def visualize_units(recording, recording_f, sorting, wf, templates,
                    num_channels, file_name):
    """Save an overview figure plus one figure per (unit, channel) pair.

    All PNGs are written to ``waveform_visualization/<file_name>/`` and are
    prefixed with ``file_name``.

    Parameters
    ----------
    recording : RecordingExtractor
        Used for waveform extraction, PCA features and amplitudes.
    recording_f : RecordingExtractor
        Its traces are drawn in the overview figure (presumably the filtered
        version of ``recording`` -- confirm).
    sorting : SortingExtractor
        Sorting to visualize. As elsewhere in this function, unit id
        ``idx + 1`` is taken to correspond to ``wf[idx]``/``templates[idx]``.
    wf : sequence of arrays
        Per-unit waveforms, indexed ``wf[unit][waveform, channel, sample]``.
    templates : sequence of arrays
        Per-unit 2-D templates (presumably channel x sample -- confirm).
    num_channels : int
        Number of channels to produce per-unit figures for.
    file_name : str
        Output sub-directory, figure title and file-name prefix.
    """
    # Waveforms cut from 4 ms before each spike up to the spike (0 ms after);
    # used for the "Random waveforms" panels.
    wf2 = st.postprocessing.get_unit_waveforms(recording,
                                               sorting,
                                               ms_before=4,
                                               ms_after=0,
                                               verbose=True)
    num_units = len(wf)
    # Normalize every template once instead of inside the nested loops below.
    templates_arr = np.array(templates)
    normalized_templates = [preprocessing.normalize(templates_arr[i][:, :])
                            for i in range(num_units)]
    color = cm.rainbow(np.linspace(0, 1, num_units))

    output_dir = "waveform_visualization/" + file_name + "/"
    # makedirs also creates a missing parent directory and accepts an
    # existing one; os.mkdir raised in either case.
    os.makedirs(output_dir, exist_ok=True)

    # Overview: filtered traces, all templates, rasters and PCA features.
    fig, axs = plt.subplots(3, 2, constrained_layout=True)
    axe = axs.ravel()
    sw.plot_timeseries(recording_f, ax=axe[0])
    for i, normalized in enumerate(normalized_templates):
        axe[1].plot(normalized.T, label=i + 1, color=color[i])
    sw.plot_rasters(sorting, ax=axe[2])
    sw.plot_pca_features(recording,
                         sorting,
                         colormap='rainbow',
                         nproj=3,
                         max_spikes_per_unit=100,
                         axes=axe[3:6])
    axe[1].legend(loc='lower right')
    fig.set_size_inches((8.5, 11), forward=False)
    fig.suptitle(file_name)
    # Prefix with file_name like every per-unit file below (was saved as a
    # bare "_full_info.png").
    fig.savefig(output_dir + file_name + '_full_info.png', dpi=500)
    plt.close(fig)

    for idx in range(num_units):
        unit_id = idx + 1  # sorting unit ids are 1-based; wf/templates 0-based
        for j in range(num_channels):
            fig, axs = plt.subplots(4, 3, constrained_layout=True)
            axe = axs.ravel()
            # This unit's waveforms on channel j, with its template on top.
            axe[2].plot(preprocessing.normalize(wf[idx][:, j, :]).T,
                        lw=0.3, alpha=0.3)
            axe[2].plot(normalized_templates[idx].T, lw=0.8, color='black')
            axe[0].plot(templates[idx].T, lw=0.8)
            # All normalized templates, with the other units faded.
            for i, normalized in enumerate(normalized_templates):
                if idx == i:
                    axe[1].plot(normalized.T, label=i + 1)
                else:
                    axe[1].plot(normalized.T, label=i + 1, alpha=0.3)
            # Three random waveforms of this unit: indices must come from this
            # unit's own waveform count, not from the number of units.
            num_waveforms = len(wf2[idx])
            if num_waveforms > 0:
                for i, val in enumerate(randint(0, num_waveforms, 3)):
                    axe[3 + i].plot(wf2[idx][val, j, :].T)
            # PCA: every unit with the 'binary' colormap, then this unit
            # overlaid on the same axes.
            sw.plot_pca_features(recording,
                                 sorting,
                                 colormap='binary',
                                 figure=fig,
                                 nproj=3,
                                 axes=axe[6:9],
                                 max_spikes_per_unit=100)
            sw.plot_pca_features(recording,
                                 sorting,
                                 unit_ids=[unit_id],
                                 figure=fig,
                                 nproj=3,
                                 axes=axe[6:9],
                                 max_spikes_per_unit=100)
            sw.plot_isi_distribution(sorting,
                                     unit_ids=[unit_id],
                                     bins=100,
                                     window=1,
                                     axes=axe[9:11])
            sw.plot_amplitudes_distribution(recording,
                                            sorting,
                                            max_spikes_per_unit=300,
                                            unit_ids=[unit_id],
                                            axes=axe[10:12])
            sp = sorting.get_unit_spike_train(unit_id=unit_id)
            axe[11].plot(sp)
            axe[1].legend(loc='lower right', shadow=True)
            axe[0].set_title('Template')
            axe[2].set_title('Waveforms, template')
            axe[3].set_title('Random waveforms')
            axe[6].set_title('PCA component analysis')
            axe[9].set_title('ISI distribution')
            axe[10].set_title('Amplitude distribution')
            axe[11].set_title('Spike frame')
            fig.set_size_inches((8.5, 11), forward=False)
            fig.suptitle("File name: " + file_name + '\nChannel: ' +
                         str(j + 1).zfill(2) + '\nUnit: ' +
                         str(unit_id).zfill(2))
            fig.savefig(output_dir + file_name + '_ch' + str(j + 1).zfill(2) +
                        '_u' + str(unit_id).zfill(2) + '.png',
                        dpi=500)
            # Free each figure once saved; otherwise num_units * num_channels
            # dpi=500 figures stay open until the end.
            plt.close(fig)
    plt.close('all')
Ejemplo n.º 7
0
def get_sort_info(sorting, recording, out_loc):
    """Print a short summary of ``sorting`` and save diagnostic plots.

    Files written to ``out_loc``: ``raster_10s.png``, ``isi.png``,
    ``pca.png``, ``waveforms_<n>.png`` (up to 20 units) and
    ``chan0_forms.png``. Nothing is plotted when the sorting has no units.

    Parameters
    ----------
    sorting : SortingExtractor
        Sorting to summarize.
    recording : RecordingExtractor
        Recording matching ``sorting``; its sampling frequency converts spike
        frames to seconds, and it is used for PCA and waveform extraction.
    out_loc : str
        Directory the PNG files are saved into (must already exist).
    """
    unit_ids = sorting.get_unit_ids()
    print("Found", len(unit_ids), 'units')
    print('Unit ids:', unit_ids)
    if len(unit_ids) == 0:
        # unit_ids[0] below and every plot need at least one unit.
        return

    spike_train = sorting.get_unit_spike_train(unit_id=unit_ids[0])
    # Convert frames to seconds with the recording's own rate instead of a
    # hard-coded 48 kHz.
    fs = recording.get_sampling_frequency()
    print('Spike train of first unit:', np.asarray(spike_train) / fs)

    # Spike raster plot
    t_len = 10
    o_loc = os.path.join(out_loc, "raster_" + str(t_len) + "s.png")
    print("Saving {}s rasters to {}".format(t_len, o_loc))
    sw.plot_rasters(sorting, trange=[0, t_len])
    plt.savefig(o_loc, dpi=200)

    sw.plot_isi_distribution(sorting, bins=10, window=1)
    o_loc = os.path.join(out_loc, "isi.png")
    print("Saving isi to {}".format(o_loc))
    plt.savefig(o_loc, dpi=200)

    # Can plot cross corr using - ignore for now
    # w_cch = sw.plot_crosscorrelograms(
    # sorting, unit_ids=[1, 5, 8], bin_size=0.1, window=5)
    sw.plot_pca_features(recording,
                         sorting,
                         colormap='rainbow',
                         nproj=3,
                         max_spikes_per_unit=100)
    o_loc = os.path.join(out_loc, "pca.png")
    print("Plotting pca to {}".format(o_loc))
    plt.savefig(o_loc, dpi=200)

    # See also spiketoolkit.postprocessing.get_unit_waveforms
    num_samps = min(20, len(unit_ids))
    sw.plot_unit_waveforms(sorting=sorting,
                           recording=recording,
                           unit_ids=unit_ids[:num_samps],
                           max_spikes_per_unit=20)
    o_loc = os.path.join(out_loc, "waveforms_" + str(num_samps) + ".png")
    print("Saving {} waveforms to {}".format(num_samps, o_loc))
    plt.savefig(o_loc, dpi=200)

    wf_by_group = st.postprocessing.get_unit_waveforms(
        recording,
        sorting,
        ms_before=1,
        ms_after=2,
        save_as_features=False,
        verbose=True,
        grouping_property="group",
        compute_property_from_recording=True)
    o_loc = os.path.join(out_loc, "chan0_forms.png")
    fig, ax = plt.subplots()
    # NOTE(review): presumably wf_by_group[0] is a 3-D array
    # (waveform, channel, sample) for the first group -- confirm.
    wf = wf_by_group[0]
    colors = ["k", "r", "b", "g"]
    for i in range(wf.shape[1]):
        # Cycle the palette so groups with more than 4 channels don't raise.
        ax.plot(wf[:, i, :].T, color=colors[i % len(colors)], lw=0.3)
    print("Saving first waveform on the first tetrode to {}".format(o_loc))
    fig.savefig(o_loc, dpi=200)
##############################################################################
# First, let's create a toy example with the `extractors` module:

recording, sorting_true = se.example_datasets.toy_example(
    duration=10,
    num_channels=4,
    seed=0,
)

##############################################################################
# `recording` is a `RecordingExtractor` object. It gives access to the channel ids, the channel locations
# (if present), the sampling frequency of the recording, and the extracellular traces. `sorting_true` is a
# `SortingExtractor` object, which holds the spike-sorting information: unit ids,
# spike trains, etc. Since the data are simulated, `sorting_true` contains the ground-truth spiking
# activity of each unit.
#
# Let's use the widgets to visualize the traces and the raster plots.

w_ts = sw.plot_timeseries(
    recording,
)
w_rs = sw.plot_rasters(
    sorting_true,
)

##############################################################################
# This is how you retrieve info from a `RecordingExtractor`...

channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)

##############################################################################
# ...and a `SortingExtractor`
unit_ids = sorting_true.get_unit_ids()