def test_select_units():
    """Selecting a subset of units keeps the 'spike_amplitudes' extension available."""
    waveform_extractor = WaveformExtractor.load_from_folder('mearec_waveforms')
    # Called for its side effect: after this, 'spike_amplitudes' is expected to be an
    # extension of the extractor, which select_units should carry over.
    compute_spike_amplitudes(waveform_extractor, load_if_exists=True)

    every_other_unit = waveform_extractor.sorting.get_unit_ids()[::2]
    selected = waveform_extractor.select_units(every_other_unit, 'mearec_waveforms_filt')
    extension_names = selected.get_available_extension_names()
    assert "spike_amplitudes" in extension_names


def test_compute_spike_amplitudes_parallel():
    """Spike amplitudes must be identical whether computed with 1 or 2 jobs."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    we = extract_waveforms(recording, sorting, Path('mearec_waveforms_all'),
                           ms_before=1., ms_after=2.,
                           max_spikes_per_unit=None,
                           n_jobs=1, chunk_size=30000,
                           load_if_exists=True)

    # Every option except n_jobs is shared, so any difference comes from parallelism.
    shared_options = dict(peak_sign='neg',
                          load_if_exists=False,
                          outputs='concatenated',
                          chunk_size=10000)
    amplitudes_single_job = compute_spike_amplitudes(we, n_jobs=1, **shared_options)
    # TODO: fix multiprocessing for spike amplitudes
    amplitudes_two_jobs = compute_spike_amplitudes(we, n_jobs=2, **shared_options)

    # compare the concatenated amplitudes of the first segment
    assert np.array_equal(amplitudes_single_job[0], amplitudes_two_jobs[0])
    def setUp(self):
        """Load the MEArec 10 s test dataset and build the fixtures shared by the tests."""
        dataset_path = download_dataset(remote_path='mearec/mearec_test_10s.h5')
        self._rec = se.MEArecRecordingExtractor(dataset_path)
        self._sorting = se.MEArecSortingExtractor(dataset_path)
        self.num_units = len(self._sorting.get_unit_ids())

        # load_if_exists=True: presumably reuses './mearec_test' if a previous run created it
        self._we = extract_waveforms(self._rec, self._sorting, './mearec_test',
                                     load_if_exists=True)
        self._amplitudes = st.compute_spike_amplitudes(self._we, peak_sign='neg',
                                                       outputs='by_unit')
        # the sorting is compared against itself
        self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting)


def export_to_phy(waveform_extractor, output_folder, compute_pc_features=True,
                  compute_amplitudes=True, sparsity_dict=None, copy_binary=True,
                  max_channels_per_template=16, remove_if_exists=False,
                  peak_sign='neg', template_mode='median',
                  dtype=None, verbose=True, **job_kwargs):
    """
    Exports a waveform extractor to the phy template-gui format.

    Units without any spike are not exported. All per-unit outputs (templates.npy,
    template_ind.npy, similar_templates.npy, pc_feature_ind.npy and the .tsv files)
    and the spike labels (spike_templates.npy / spike_clusters.npy) use the same
    0-based numbering of the exported (non-empty) units. This numbering is the
    'cluster_id' column of the .tsv files.

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor to export. Its recording and sorting are exported too.
        Only single-segment data is supported.
    output_folder: str or Path
        The output folder where the phy template-gui files are saved
    compute_pc_features: bool
        If True (default), pc features are saved. They come from the existing
        'principal_components' extension, or are computed if it does not exist.
    compute_amplitudes: bool
        If True (default), spike amplitudes are saved. They come from the existing
        'spike_amplitudes' extension, or are computed if it does not exist.
    sparsity_dict: dict or None
        If given, the dictionary should contain a sparsity method (e.g. "best_channels") and optionally
        arguments associated with the method (e.g. "num_channels" for "best_channels" method).
        Other examples are:
           * by radius: sparsity_dict=dict(method="radius", radius_um=100)
           * by SNR threshold: sparsity_dict=dict(method="threshold", threshold=2)
           * by property: sparsity_dict=dict(method="by_property", by_property="group")
        Default is sparsity_dict=dict(method="best_channels", num_channels=16)
        For more info, see the toolkit.get_template_channel_sparsity() function.
    max_channels_per_template: int or None
        Maximum channels per unit to return. If None, all channels are returned
    copy_binary: bool
        If True, the recording is copied and saved in the phy 'output_folder'.
        If False and the recording is a BinaryRecordingExtractor, its own file is referenced.
        Otherwise no binary file is referenced (dat_path = 'None').
    remove_if_exists: bool
        If True and 'output_folder' exists, it is removed and overwritten.
        If False and it exists, a FileExistsError is raised.
    peak_sign: 'neg', 'pos', 'both'
        Used by compute_spike_amplitudes, by the pc features computation and to
        select the best channels stored in pc_feature_ind.npy
    template_mode: str
        Parameter 'mode' to be given to WaveformExtractor.get_template()
    dtype: dtype or None
        Dtype to save binary data. If None, the recording dtype is used
    verbose: bool
        If True, output is verbose
    {}
    """
    assert isinstance(waveform_extractor, spikeinterface.core.waveform_extractor.WaveformExtractor), \
        'waveform_extractor must be a WaveformExtractor object'
    recording = waveform_extractor.recording
    sorting = waveform_extractor.sorting

    assert recording.get_num_segments() == sorting.get_num_segments(), \
        "The recording and sorting objects must have the same number of segments!"

    assert recording.get_num_segments() == 1, "Export to phy work only with one segment"

    if sparsity_dict is None:
        sparsity_dict = dict(method="best_channels", num_channels=16)

    channel_ids = recording.channel_ids
    num_chans = recording.get_num_channels()
    fs = recording.get_sampling_frequency()

    if max_channels_per_template is None:
        max_channels_per_template = num_chans

    # Keep only units with at least one spike. keep_unit_indices holds the positions of
    # the kept units in the full sorting. It is used below to re-align every array that
    # is ordered like the full sorting (spike labels, similarity, groups, metrics).
    all_unit_ids = sorting.get_unit_ids()
    unit_ids = []
    keep_unit_indices = []
    for unit_index, unit_id in enumerate(all_unit_ids):
        if len(sorting.get_unit_spike_train(unit_id)) > 0:
            unit_ids.append(unit_id)
            keep_unit_indices.append(unit_index)
    keep_unit_indices = np.array(keep_unit_indices, dtype='int64')
    empty_flag = len(unit_ids) < len(all_unit_ids)
    if empty_flag:
        print('Warning: empty units have been removed when being exported to Phy')

    if not recording.is_filtered():
        print("Warning: recording is not filtered! It's recommended to filter the recording before exporting to phy.\n"
              "You can run spikeinterface.toolkit.preprocessing.bandpass_filter(recording)")

    if len(unit_ids) == 0:
        raise Exception("No non-empty units in the sorting result, can't save to Phy.")

    output_folder = Path(output_folder).absolute()
    if output_folder.is_dir():
        if remove_if_exists:
            shutil.rmtree(output_folder)
        else:
            raise FileExistsError(f'{output_folder} already exists')

    output_folder.mkdir()

    # save dat file
    if dtype is None:
        dtype = recording.get_dtype()

    if copy_binary:
        rec_path = output_folder / 'recording.dat'
        write_binary_recording(recording, file_paths=rec_path, verbose=verbose, dtype=dtype, **job_kwargs)
    elif isinstance(recording, BinaryRecordingExtractor):
        # reference the existing binary file; its dtype overrides the requested one
        rec_path = recording._kwargs['file_paths'][0]
        dtype = recording.get_dtype()
    else:  # don't save recording.dat
        rec_path = 'None'

    dtype_str = np.dtype(dtype).name

    # write params.py
    with (output_folder / 'params.py').open('w') as f:
        f.write(f"dat_path = r'{str(rec_path)}'\n")
        f.write(f"n_channels_dat = {num_chans}\n")
        f.write(f"dtype = '{dtype_str}'\n")
        f.write(f"offset = 0\n")
        f.write(f"sample_rate = {fs}\n")
        f.write(f"hp_filtered = {recording.is_filtered()}")

    # export spike_times/spike_templates/spike_clusters
    # spike_labels is the unit index in the full sorting (first and only segment)
    all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
    spike_times, spike_labels = all_spikes[0]
    if empty_flag:
        # Renumber labels so they index the exported units (the rows of templates.npy).
        # Empty units own no spikes, so every label has a valid target.
        full_to_exported = np.full(len(all_unit_ids), -1, dtype='int64')
        full_to_exported[keep_unit_indices] = np.arange(len(unit_ids))
        spike_labels = full_to_exported[spike_labels].astype(spike_labels.dtype)
    np.save(str(output_folder / 'spike_times.npy'), spike_times[:, np.newaxis])
    np.save(str(output_folder / 'spike_templates.npy'), spike_labels[:, np.newaxis])
    np.save(str(output_folder / 'spike_clusters.npy'), spike_labels[:, np.newaxis])

    # export templates/templates_ind/similar_templates
    # shape (num_units, num_samples, num_channels)
    templates = []
    templates_ind = []

    template_sparsity = get_template_channel_sparsity(waveform_extractor,
                                                      outputs="id", **sparsity_dict)
    num_sparse_chans = np.max([len(channels) for channels in template_sparsity.values()])
    num_channels = np.min([max_channels_per_template, num_sparse_chans])
    for unit_id in unit_ids:
        template = waveform_extractor.get_template(unit_id, mode=template_mode, sparsity=template_sparsity)
        inds = waveform_extractor.recording.ids_to_indices(template_sparsity[unit_id])
        if template.shape[-1] < num_channels:
            # pad to a common channel count; padded channel indices are -1
            template_full = np.zeros((template.shape[0], num_channels))
            template_full[:, :template.shape[-1]] = template
            inds_full = np.concatenate((inds, np.array([-1] * (num_channels - template.shape[-1]))))
        else:
            # NOTE(review): if a unit's sparsity has more channels than
            # max_channels_per_template, rows end up with different widths.
            template_full = template
            inds_full = inds
        templates.append(template_full)
        templates_ind.append(inds_full)

    template_similarity = compute_template_similarity(waveform_extractor, method='cosine_similarity')
    if empty_flag:
        # NOTE(review): assumes the similarity matrix is ordered like sorting.get_unit_ids()
        template_similarity = template_similarity[np.ix_(keep_unit_indices, keep_unit_indices)]

    np.save(str(output_folder / 'templates.npy'), templates)
    np.save(str(output_folder / 'template_ind.npy'), templates_ind)
    np.save(str(output_folder / 'similar_templates.npy'), template_similarity)

    channel_maps = np.arange(num_chans, dtype='int32')
    channel_map_si = waveform_extractor.recording.get_channel_ids()
    channel_positions = recording.get_channel_locations().astype('float32')
    channel_groups = recording.get_channel_groups()
    if channel_groups is None:
        channel_groups = np.zeros(num_chans, dtype='int32')
    np.save(str(output_folder / 'channel_map.npy'), channel_maps)
    np.save(str(output_folder / 'channel_map_si.npy'), channel_map_si)
    np.save(str(output_folder / 'channel_positions.npy'), channel_positions)
    np.save(str(output_folder / 'channel_groups.npy'), channel_groups)

    if compute_amplitudes:
        if waveform_extractor.is_extension('spike_amplitudes'):
            sac = waveform_extractor.load_extension('spike_amplitudes')
            amplitudes = sac.get_amplitudes(outputs='concatenated')
        else:
            amplitudes = compute_spike_amplitudes(waveform_extractor, peak_sign=peak_sign, outputs='concatenated',
                                                  **job_kwargs)
        # one segment only
        amplitudes = amplitudes[0][:, np.newaxis]
        np.save(str(output_folder / 'amplitudes.npy'), amplitudes)

    if compute_pc_features:
        if waveform_extractor.is_extension('principal_components'):
            pc = waveform_extractor.load_extension('principal_components')
        else:
            pc = compute_principal_components(waveform_extractor, n_components=5, mode='by_channel_local')

        max_channels_per_template = min(max_channels_per_template, len(channel_ids))
        pc.run_for_all_spikes(output_folder / 'pc_features.npy',
                              max_channels_per_template=max_channels_per_template, peak_sign=peak_sign,
                              **job_kwargs)

        pc_feature_ind = np.zeros((len(unit_ids), max_channels_per_template), dtype='int64')
        best_channels_index = get_template_channel_sparsity(waveform_extractor, method='best_channels',
                                                            peak_sign=peak_sign, num_channels=max_channels_per_template,
                                                            outputs='index')
        # iterate the exported units only: pc_feature_ind has one row per exported unit
        for u, unit_id in enumerate(unit_ids):
            pc_feature_ind[u, :] = best_channels_index[unit_id]
        np.save(str(output_folder / 'pc_feature_ind.npy'), pc_feature_ind)

    # Save .tsv metadata (one row per exported unit, cluster_id = row index)
    cluster_ids = list(range(len(unit_ids)))
    cluster_group = pd.DataFrame({'cluster_id': cluster_ids,
                                  'group': ['unsorted'] * len(unit_ids)})
    cluster_group.to_csv(output_folder / 'cluster_group.tsv',
                         sep="\t", index=False)
    si_unit_ids = pd.DataFrame({'cluster_id': cluster_ids,
                                'si_unit_id': unit_ids})
    si_unit_ids.to_csv(output_folder / 'cluster_si_unit_ids.tsv',
                       sep="\t", index=False)

    unit_groups = sorting.get_property('group')
    if unit_groups is None:
        unit_groups = np.zeros(len(unit_ids), dtype='int32')
    else:
        # property values are ordered like the full sorting
        unit_groups = np.asarray(unit_groups)[keep_unit_indices]
    channel_group = pd.DataFrame({'cluster_id': cluster_ids,
                                  'channel_group': unit_groups})
    channel_group.to_csv(output_folder / 'cluster_channel_group.tsv',
                         sep="\t", index=False)

    if waveform_extractor.is_extension('quality_metrics'):
        qm = waveform_extractor.load_extension('quality_metrics')
        qm_data = qm.get_metrics()
        if empty_flag:
            # NOTE(review): assumes the metrics DataFrame is indexed by unit id
            qm_data = qm_data.loc[unit_ids]
        for column_name in qm_data.columns:
            # already computed by phy
            if column_name not in ["num_spikes", "firing_rate"]:
                metric = pd.DataFrame({'cluster_id': cluster_ids,
                                       column_name: qm_data[column_name].values})
                metric.to_csv(output_folder / f'cluster_{column_name}.tsv',
                              sep="\t", index=False)

    if verbose:
        print('Run:\nphy template-gui ', str(output_folder / 'params.py'))


def export_report(waveform_extractor,
                  output_folder,
                  remove_if_exists=False,
                  format="png",
                  show_figures=False,
                  peak_sign='neg',
                  **job_kwargs):
    """
    Exports a SI spike sorting report. The report includes summary figures of the spike sorting output
    (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports,
    that include waveforms, templates, template maps, ISI distributions, and more.

    Spike amplitudes, principal components and quality metrics are taken from the
    waveform extractor's extensions when they exist, and computed otherwise.

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor to build the report from
    output_folder: str or Path
        The output folder where the report files are saved
    remove_if_exists: bool
        If True and the output folder exists, it is removed.
        If False and it exists, a FileExistsError is raised.
    format: str
        'png' (default) or 'pdf' or any format handled by matplotlib
    show_figures: bool
        If True, figures are shown. If False (default), figures are closed after saving.
    peak_sign: 'neg' or 'pos'
        Used to compute spike amplitudes and to select each unit's extremum
        channel and amplitude in the unit list
    {}
    """
    we = waveform_extractor
    sorting = we.sorting
    unit_ids = sorting.unit_ids

    # 'format' is not validated here: matplotlib raises for unsupported formats

    # Ensure the spike_amplitudes extension exists (presumably needed by the amplitude plots).
    if not we.is_extension('spike_amplitudes'):
        st.compute_spike_amplitudes(we,
                                    peak_sign=peak_sign,
                                    outputs='by_unit',
                                    **job_kwargs)

    output_folder = Path(output_folder).absolute()
    if output_folder.is_dir():
        if remove_if_exists:
            shutil.rmtree(output_folder)
        else:
            raise FileExistsError(f'{output_folder} already exists')
    output_folder.mkdir(parents=True, exist_ok=True)

    def _save_figure(fig, file_path):
        # save the figure, then close it unless the caller asked to keep figures open
        fig.savefig(file_path)
        if not show_figures:
            plt.close(fig)

    # unit list
    units = pd.DataFrame(index=unit_ids)
    units.index.name = 'unit_id'
    units['max_on_channel_id'] = pd.Series(
        st.get_template_extremum_channel(we, peak_sign=peak_sign, outputs='id'))
    units['amplitude'] = pd.Series(
        st.get_template_extremum_amplitude(we, peak_sign=peak_sign))
    units.to_csv(output_folder / 'unit list.csv', sep='\t')

    # metrics
    if we.is_extension('quality_metrics'):
        metrics = we.load_extension('quality_metrics').get_metrics()
    else:
        # quality metrics need principal components: compute them if missing
        if not we.is_extension('principal_components'):
            st.compute_principal_components(we,
                                            load_if_exists=True,
                                            n_components=5,
                                            mode='by_channel_local')
        metrics = st.compute_quality_metrics(we)
    metrics.to_csv(output_folder / 'quality metrics.csv')

    unit_colors = sw.get_unit_colors(sorting)

    # global figures
    fig = plt.figure(figsize=(20, 10))
    sw.plot_unit_localization(we, figure=fig, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'unit_localization.{format}')

    fig, ax = plt.subplots(figsize=(20, 10))
    sw.plot_units_depth_vs_amplitude(we, ax=ax, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'units_depth_vs_amplitude.{format}')

    fig = plt.figure(figsize=(20, 10))
    sw.plot_amplitudes_distribution(we, figure=fig, unit_colors=unit_colors)
    _save_figure(fig, output_folder / f'amplitudes_distribution.{format}')

    # one summary figure per unit
    units_folder = output_folder / 'units'
    units_folder.mkdir()

    for unit_id in unit_ids:
        fig = plt.figure(constrained_layout=False, figsize=(15, 7))
        sw.plot_unit_summary(we, unit_id, figure=fig)
        fig.suptitle(f'unit {unit_id}')
        _save_figure(fig, units_folder / f'{unit_id}.{format}')


def test_compute_spike_amplitudes():
    """Compute spike amplitudes (raw and scaled) and reload them as an extension."""
    repo = 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'
    remote_path = 'mearec/mearec_test_10s.h5'
    local_path = download_dataset(repo=repo,
                                  remote_path=remote_path,
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    folder = Path('mearec_waveforms')

    we = extract_waveforms(recording,
                           sorting,
                           folder,
                           ms_before=1.,
                           ms_after=2.,
                           max_spikes_per_unit=500,
                           n_jobs=1,
                           chunk_size=30000,
                           load_if_exists=False,
                           overwrite=True)

    # both output layouts must compute without error
    amplitudes_concatenated = compute_spike_amplitudes(we,
                                                       peak_sign='neg',
                                                       outputs='concatenated',
                                                       chunk_size=10000,
                                                       n_jobs=1)
    amplitudes_by_unit = compute_spike_amplitudes(we,
                                                  peak_sign='neg',
                                                  outputs='by_unit',
                                                  chunk_size=10000,
                                                  n_jobs=1)

    # with a known gain and zero offset, scaled amplitudes = raw amplitudes * gain
    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)

    folder = Path('mearec_waveforms_scaled')

    we_scaled = extract_waveforms(recording,
                                  sorting,
                                  folder,
                                  ms_before=1.,
                                  ms_after=2.,
                                  max_spikes_per_unit=500,
                                  n_jobs=1,
                                  chunk_size=30000,
                                  load_if_exists=False,
                                  overwrite=True,
                                  return_scaled=True)

    amplitudes_scaled = compute_spike_amplitudes(we_scaled,
                                                 peak_sign='neg',
                                                 outputs='concatenated',
                                                 chunk_size=10000,
                                                 n_jobs=1,
                                                 return_scaled=True)
    amplitudes_unscaled = compute_spike_amplitudes(we_scaled,
                                                   peak_sign='neg',
                                                   outputs='concatenated',
                                                   chunk_size=10000,
                                                   n_jobs=1,
                                                   return_scaled=False)

    assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain)

    # reload as an extension from we
    assert SpikeAmplitudesCalculator in we.get_available_extensions()
    assert we_scaled.is_extension('spike_amplitudes')
    sac = we.load_extension('spike_amplitudes')
    assert isinstance(sac, SpikeAmplitudesCalculator)
    assert sac._amplitudes is not None

    # reload directly from the scaled extractor's folder and check that object
    sac_loaded = SpikeAmplitudesCalculator.load_from_folder(folder)
    assert sac_loaded._amplitudes is not None