예제 #1
0
sw.plot_unit_probe_map(we, unit_ids=unit_ids)

##############################################################################
# plot_unit_waveform_density_map()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is your best friend for checking whether units have been over-merged.

# Only the first four units of the sorting are passed to the density map.
unit_ids = sorting.unit_ids[:4]
sw.plot_unit_waveform_density_map(we, unit_ids=unit_ids, max_channels=5)

##############################################################################
# plot_amplitudes_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

sw.plot_amplitudes_distribution(we)

##############################################################################
# plot_amplitudes_timeseries()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

sw.plot_amplitudes_timeseries(we)

##############################################################################
# plot_units_depth_vs_amplitude()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

sw.plot_units_depth_vs_amplitude(we)

# Display every figure created above.
plt.show()
예제 #2
0
 def test_amplitudes_distribution(self):
     """Smoke-test the amplitudes distribution widget.

     The widget is called twice on the same waveform extractor: once with
     its default arguments and once with an explicit ``amplitudes`` argument.
     """
     extractor = self._we
     # Call without the optional amplitudes argument.
     sw.plot_amplitudes_distribution(extractor)
     # Call again, this time handing over the amplitudes stored on the test case.
     sw.plot_amplitudes_distribution(extractor, amplitudes=self._amplitudes)
예제 #3
0
def visualize_units(recording, recording_f, sorting, wf, templates,
                    num_channels, file_name):
    """Save an overview figure plus one diagnostic figure per unit and channel.

    All images are written as PNG files (dpi=500) under
    ``waveform_visualization/<file_name>/``; the directory is created when
    it does not exist yet.

    Parameters
    ----------
    recording :
        Recording handed to the waveform extraction, the PCA plots and the
        amplitude distribution plot.
    recording_f :
        Recording shown in the time-series panel of the overview figure
        (presumably the filtered recording -- confirm with the caller).
    sorting :
        Sorting object. The ISI, amplitude and spike-train panels address
        the unit at position ``idx`` as unit id ``idx + 1``.
    wf :
        Per-unit waveforms, indexed as ``wf[unit][:, channel, :]``.
    templates :
        Per-unit templates; entry ``i`` belongs to ``wf[i]``.
    num_channels : int
        Number of channels; one figure is produced per unit and channel.
    file_name : str
        Name used for the output sub-directory, the figure titles and the
        per-unit file names.
    """
    wf2 = st.postprocessing.get_unit_waveforms(recording,
                                               sorting,
                                               ms_before=4,
                                               ms_after=0,
                                               verbose=True)
    num_units = len(wf)
    color = cm.rainbow(np.linspace(0, 1, num_units))
    output_dir = "waveform_visualization/" + file_name + "/"
    # makedirs (unlike mkdir) also creates the missing parent directory and
    # does not fail when the directory already exists.
    os.makedirs(output_dir, exist_ok=True)

    # The normalized templates are the same for every figure, so compute
    # them once instead of once per (unit, channel, unit) iteration.
    templates_arr = np.array(templates)
    normalized_templates = [
        preprocessing.normalize(templates_arr[i][:, :])
        for i in range(num_units)
    ]

    # ---- overview figure -------------------------------------------------
    fig, axs = plt.subplots(3, 2, constrained_layout=True)
    axe = axs.ravel()
    sw.plot_timeseries(recording_f, ax=axe[0])
    for i, normalized in enumerate(normalized_templates):
        axe[1].plot(normalized.T, label=i + 1, color=color[i])
    sw.plot_rasters(sorting, ax=axe[2])
    sw.plot_pca_features(recording,
                         sorting,
                         colormap='rainbow',
                         nproj=3,
                         max_spikes_per_unit=100,
                         axes=axe[3:6])
    axe[1].legend(loc='lower right')
    fig.set_size_inches((8.5, 11), forward=False)
    fig.suptitle(file_name)
    # NOTE(review): unlike the per-unit figures below, this file name has no
    # ``file_name`` prefix -- kept as is so existing outputs keep their name.
    fig.savefig(output_dir + '_full_info.png', dpi=500)
    plt.close(fig)

    # ---- one figure per unit and channel ----------------------------------
    for idx in range(num_units):
        # Number of extracted spikes for this unit (first axis of wf2[idx]).
        n_spikes = len(wf2[idx])
        for j in range(num_channels):
            fig, axs = plt.subplots(4, 3, constrained_layout=True)
            axe = axs.ravel()

            # Panel 2: all waveforms of this unit/channel with the template
            # drawn on top in black.
            normalized = preprocessing.normalize(wf[idx][:, j, :])
            axe[2].plot(normalized.T, lw=0.3, alpha=0.3)
            axe[2].plot(normalized_templates[idx].T, lw=0.8, color='black')

            # Panel 0: raw template of the current unit.
            axe[0].plot(templates[idx].T, lw=0.8)

            # Panel 1: every unit's template, the current one highlighted.
            for i, normalized in enumerate(normalized_templates):
                if idx == i:
                    axe[1].plot(normalized.T, label=i + 1)
                else:
                    axe[1].plot(normalized.T, label=i + 1, alpha=0.3)

            # Panels 3-5: three randomly chosen waveforms. The indices select
            # spikes, so they must be drawn from the spike count of this unit
            # (not from the number of units); skip when there are no spikes.
            values = randint(0, n_spikes, 3) if n_spikes else []
            for i, val in enumerate(values):
                axe[3 + i].plot(wf2[idx][val, j, :].T)

            # Panels 6-8: PCA of all units in the background ...
            sw.plot_pca_features(recording,
                                 sorting,
                                 colormap='binary',
                                 figure=fig,
                                 nproj=3,
                                 axes=axe[6:9],
                                 max_spikes_per_unit=100)
            # ... and of a single unit on top.
            # NOTE(review): this call uses ``idx`` while every other panel
            # uses ``idx + 1`` as the unit id -- confirm which is intended.
            sw.plot_pca_features(recording,
                                 sorting,
                                 unit_ids=[idx],
                                 figure=fig,
                                 nproj=3,
                                 axes=axe[6:9],
                                 max_spikes_per_unit=100)
            sw.plot_isi_distribution(sorting,
                                     unit_ids=[idx + 1],
                                     bins=100,
                                     window=1,
                                     axes=axe[9:11])
            sw.plot_amplitudes_distribution(recording,
                                            sorting,
                                            max_spikes_per_unit=300,
                                            unit_ids=[idx + 1],
                                            axes=axe[10:12])
            sp = sorting.get_unit_spike_train(unit_id=idx + 1)
            axe[11].plot(sp)

            axe[1].legend(loc='lower right', shadow=True)
            axe[0].set_title('Template')
            axe[2].set_title('Waveforms, template')
            axe[3].set_title('Random waveforms')
            axe[6].set_title('PCA component analysis')
            axe[9].set_title('ISI distribution')
            axe[10].set_title('Amplitude distribution')
            axe[11].set_title('Spike frame')
            fig.set_size_inches((8.5, 11), forward=False)
            fig.suptitle("File name: " + file_name + '\nChannel: ' +
                         str(j + 1).zfill(2) + '\nUnit: ' +
                         str(idx + 1).zfill(2))
            fig.savefig(output_dir + file_name + '_ch' + str(j + 1).zfill(2) +
                        '_u' + str(idx + 1).zfill(2) + '.png',
                        dpi=500)
            # Release the figure right away; otherwise one large figure per
            # (unit, channel) stays in memory until the function returns.
            plt.close(fig)
    plt.close('all')
예제 #4
0
def export_report(waveform_extractor,
                  output_folder,
                  remove_if_exists=False,
                  format="png",
                  show_figures=False,
                  peak_sign='neg',
                  **job_kwargs):
    """
    Exports a SI spike sorting report. The report includes summary figures of the spike sorting output
    (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports,
    that include waveforms, templates, template maps, ISI distributions, and more.

    The following files are written to ``output_folder``: 'unit list.csv', 'quality metrics.csv',
    three global figures and one summary figure per unit in the 'units' sub-folder.

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor to report on. Its 'spike_amplitudes' and 'quality_metrics'
        extensions are reused when present and computed otherwise.
    output_folder: str
        The output folder where the report files are saved
    remove_if_exists: bool
        If True and the output folder exists, it is removed
    format: str
        'png' (default) or 'pdf' or any format handled by matplotlib
    peak_sign: 'neg' or 'pos'
        Used to compute the spike amplitudes and the extremum channel and amplitude of the unit list
    show_figures: bool
        If True, figures are shown. If False (default), figures are closed after saving.
    {}

    Raises
    ------
    FileExistsError
        If the output folder already exists and remove_if_exists is False.
    """
    we = waveform_extractor
    sorting = we.sorting
    unit_ids = sorting.unit_ids

    output_folder = Path(output_folder).absolute()
    # Fail fast: refuse an existing folder before any expensive computation.
    # The folder itself is only removed further down, once the amplitudes
    # are available, so a failed computation leaves an old report intact.
    if output_folder.is_dir() and not remove_if_exists:
        raise FileExistsError(f'{output_folder} already exists')

    # The format is not validated here: matplotlib does the check when
    # saving, which also allows e.g. svg.

    # NOTE(review): ``amplitudes`` is not referenced again in this function;
    # the calls are kept because they may make the extension available to
    # the plotting functions below -- confirm.
    if we.is_extension('spike_amplitudes'):
        sac = we.load_extension('spike_amplitudes')
        amplitudes = sac.get_amplitudes(outputs='by_unit')
    else:
        amplitudes = st.compute_spike_amplitudes(we,
                                                 peak_sign=peak_sign,
                                                 outputs='by_unit',
                                                 **job_kwargs)

    if output_folder.is_dir():
        if remove_if_exists:
            shutil.rmtree(output_folder)
        else:
            raise FileExistsError(f'{output_folder} already exists')
    output_folder.mkdir(parents=True, exist_ok=True)

    def _save_figure(fig, path):
        # Write the figure to disk and close it unless figures are to be shown.
        fig.savefig(path)
        if not show_figures:
            plt.close(fig)

    # unit list (tab-separated); honours the requested peak sign
    units = pd.DataFrame(index=unit_ids)
    units.index.name = 'unit_id'
    units['max_on_channel_id'] = pd.Series(
        st.get_template_extremum_channel(we, peak_sign=peak_sign, outputs='id'))
    units['amplitude'] = pd.Series(
        st.get_template_extremum_amplitude(we, peak_sign=peak_sign))
    units.to_csv(output_folder / 'unit list.csv', sep='\t')

    # metrics
    if we.is_extension('quality_metrics'):
        qmc = we.load_extension('quality_metrics')
        # NOTE(review): reads a private attribute of the extension object.
        metrics = qmc._metrics
    else:
        # compute principal_components if not done
        if not we.is_extension('principal_components'):
            st.compute_principal_components(we,
                                            load_if_exists=True,
                                            n_components=5,
                                            mode='by_channel_local')
        metrics = st.compute_quality_metrics(we)
    metrics.to_csv(output_folder / 'quality metrics.csv')

    # The same colors are used for the units in all global figures.
    unit_colors = sw.get_unit_colors(sorting)

    # global figures
    fig = plt.figure(figsize=(20, 10))
    sw.plot_unit_localization(we, figure=fig, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'unit_localization.{format}')

    fig, ax = plt.subplots(figsize=(20, 10))
    sw.plot_units_depth_vs_amplitude(we, ax=ax, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'units_depth_vs_amplitude.{format}')

    fig = plt.figure(figsize=(20, 10))
    sw.plot_amplitudes_distribution(we, figure=fig, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'amplitudes_distribution.{format}')

    # units: one summary figure per unit, named after the unit id
    units_folder = output_folder / 'units'
    units_folder.mkdir()

    for unit_id in unit_ids:
        fig = plt.figure(
            constrained_layout=False,
            figsize=(15, 7),
        )
        sw.plot_unit_summary(we, unit_id, figure=fig)
        fig.suptitle(f'unit {unit_id}')
        _save_figure(fig, units_folder / f'{unit_id}.{format}')
import spikeinterface.widgets as sw

##############################################################################
# First, let's create a toy example with the `extractors` module:

recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

##############################################################################
# plot_unit_waveforms()
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The return value of each plotting call below is kept in a ``w_*`` variable.

w_wf = sw.plot_unit_waveforms(recording, sorting, max_num_waveforms=100)

##############################################################################
# plot_amplitudes_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampd = sw.plot_amplitudes_distribution(recording, sorting, max_num_waveforms=300)

##############################################################################
# plot_amplitudes_timeseries()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampt = sw.plot_amplitudes_timeseries(recording, sorting, max_num_waveforms=300)

##############################################################################
# plot_features()
# ~~~~~~~~~~~~~~~~~~~~~~~~

w_feat = sw.plot_features(recording, sorting, colormap='rainbow', nproj=3, max_num_waveforms=100)
예제 #6
0
# Create the toy recording/sorting pair that all plots below operate on.
recording, sorting = se.example_datasets.toy_example(duration=10,
                                                     num_channels=4,
                                                     seed=0)

##############################################################################
# plot_unit_waveforms()
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The return value of each plotting call below is kept in a ``w_*`` variable.

w_wf = sw.plot_unit_waveforms(recording, sorting, max_spikes_per_unit=100)

##############################################################################
# plot_amplitudes_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampd = sw.plot_amplitudes_distribution(recording,
                                         sorting,
                                         max_spikes_per_unit=300)

##############################################################################
# plot_amplitudes_timeseries()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

w_ampt = sw.plot_amplitudes_timeseries(recording,
                                       sorting,
                                       max_spikes_per_unit=300)

##############################################################################
# plot_pca_features()
# ~~~~~~~~~~~~~~~~~~~~~~~~

w_feat = sw.plot_pca_features(recording,