Example #1
        def test_convert_recording_extractor_to_nwb(self, se_class, dataset_path, se_kwargs):
            print(f"\n\n\n TESTING {se_class.extractor_name}...")
            dataset_stem = Path(dataset_path).stem
            self.dataset.get(dataset_path)
            recording = se_class(**se_kwargs)


            # test writing to NWB
            if test_nwb:
                nwb_save_path = self.savedir / f"{se_class.__name__}_test_{dataset_stem}.nwb"
                se.NwbRecordingExtractor.write_recording(recording, nwb_save_path, write_scaled=True)
                nwb_recording = se.NwbRecordingExtractor(nwb_save_path)
                check_recordings_equal(recording, nwb_recording)

                if recording.has_unscaled:
                    nwb_save_path_unscaled = self.savedir / f"{se_class.__name__}_test_{dataset_stem}_unscaled.nwb"
                    if np.all(recording.get_channel_offsets() == 0):
                        se.NwbRecordingExtractor.write_recording(recording, nwb_save_path_unscaled, write_scaled=False)
                        nwb_recording = se.NwbRecordingExtractor(nwb_save_path_unscaled)
                        check_recordings_equal(recording, nwb_recording, return_scaled=False)
                        # Skip check when NWB converts uint to int
                        if recording.get_dtype(return_scaled=False) == nwb_recording.get_dtype(return_scaled=False):
                            check_recordings_equal(recording, nwb_recording, return_scaled=True)

            # test caching
            if test_caching:
                rec_cache = se.CacheRecordingExtractor(recording)
                check_recordings_equal(recording, rec_cache)
                if recording.has_unscaled:
                    rec_cache_unscaled = se.CacheRecordingExtractor(recording, return_scaled=False)
                    check_recordings_equal(recording, rec_cache_unscaled, return_scaled=False)
                    check_recordings_equal(recording, rec_cache_unscaled, return_scaled=True)
Example #2
    def test_cache_extractor(self):
        cache_rec = se.CacheRecordingExtractor(self.RX)
        check_recording_return_types(cache_rec)
        check_recordings_equal(self.RX, cache_rec)
        cache_rec.move_to('cache_rec')

        assert cache_rec.filename == 'cache_rec.dat'
        check_dumping(cache_rec)

        cache_rec = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2')
        check_recording_return_types(cache_rec)
        check_recordings_equal(self.RX, cache_rec)

        assert cache_rec.filename == 'cache_rec2.dat'
        check_dumping(cache_rec)

        # test saving to file
        del cache_rec
        assert Path('cache_rec2.dat').is_file()

        # test tmp
        cache_rec = se.CacheRecordingExtractor(self.RX)
        tmp_file = cache_rec.filename
        del cache_rec
        assert not Path(tmp_file).is_file()

        cache_sort = se.CacheSortingExtractor(self.SX)
        check_sorting_return_types(cache_sort)
        check_sortings_equal(self.SX, cache_sort)
        cache_sort.move_to('cache_sort')

        assert cache_sort.filename == 'cache_sort.npz'
        check_dumping(cache_sort)

        # test saving to file
        del cache_sort
        assert Path('cache_sort.npz').is_file()

        cache_sort = se.CacheSortingExtractor(self.SX, save_path='cache_sort2')
        check_sorting_return_types(cache_sort)
        check_sortings_equal(self.SX, cache_sort)

        assert cache_sort.filename == 'cache_sort2.npz'
        check_dumping(cache_sort)

        # test saving to file
        del cache_sort
        assert Path('cache_sort2.npz').is_file()

        # test tmp
        cache_sort = se.CacheSortingExtractor(self.SX)
        tmp_file = cache_sort.filename
        del cache_sort
        assert not Path(tmp_file).is_file()

        # cleanup
        os.remove('cache_rec.dat')
        os.remove('cache_rec2.dat')
        os.remove('cache_sort.npz')
        os.remove('cache_sort2.npz')
def save_si_object(object_name: str, si_object, output_folder,
                   cache_raw=False, include_properties=True, include_features=False):
    """
    Save an arbitrary SI object to a temporary location for NWB conversion.

    Parameters
    ----------
    object_name: str
        The unique name of the SpikeInterface object.
    si_object: RecordingExtractor or SortingExtractor
        The extractor to be saved.
    output_folder: str or Path
        The folder where the object is saved.
    cache_raw: bool
        If True, the Extractor is cached to a binary file (not recommended for RecordingExtractor objects)
        (default False).
    include_properties: bool
        If True, properties (channel or unit) are saved (default True).
    include_features: bool
        If True, spike features are saved (default False)
    """
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    if isinstance(si_object, se.RecordingExtractor):
        if not si_object.is_dumpable:
            cache = se.CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat")
        elif cache_raw:
            # save to json before caching to keep history (in case it's needed)
            json_file = output_folder / f"{object_name}.json"
            si_object.dump_to_json(json_file)
            cache = se.CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat")
        else:
            cache = si_object

    elif isinstance(si_object, se.SortingExtractor):
        if not si_object.is_dumpable:
            cache = se.CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz")
        elif cache_raw:
            # save to json before caching to keep history (in case it's needed)
            json_file = output_folder / f"{object_name}.json"
            si_object.dump_to_json(json_file)
            cache = se.CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz")
        else:
            cache = si_object
    else:
        raise ValueError("The 'si_object' argument shoulde be a SpikeInterface Extractor!")

    pkl_file = output_folder / f"{object_name}.pkl"
    cache.dump_to_pickle(
        pkl_file,
        include_properties=include_properties,
        include_features=include_features
    )
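A minimal usage sketch for save_si_object (assumptions: the legacy spikeextractors package is importable as se, the toy dataset stands in for a real recording, and the folder name is illustrative):

from pathlib import Path

import spikeextractors as se

# Toy recording/sorting pair, used purely for illustration.
recording, sorting = se.example_datasets.toy_example()

output_folder = Path("nwb_conversion_tmp")  # hypothetical scratch folder

# Each call writes "<object_name>.pkl" into output_folder; extractors that cannot
# be re-instantiated from their constructor arguments are first cached to
# raw.dat / sorting.npz, as in the branches above.
save_si_object("recording", recording, output_folder)
save_si_object("sorting", sorting, output_folder)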
Example #4
def bandpass_filter(recording,
                    freq_min=300,
                    freq_max=6000,
                    freq_wid=1000,
                    filter_type='fft',
                    order=3,
                    chunk_size=30000,
                    cache_to_file=False,
                    cache_chunks=False,
                    dtype=None):
    '''
    Performs a lazy filter on the recording extractor traces.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to be filtered.
    freq_min: int or float
        High-pass cutoff frequency.
    freq_max: int or float
        Low-pass cutoff frequency.
    freq_wid: int or float
        Width of the filter (when type is 'fft').
    filter_type: str
        'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain. The 'butter' filter uses
        scipy butter and filtfilt functions.
    order: int
        Order of the filter (if 'butter').
    chunk_size: int
        The chunk size to be used for the filtering.
    cache_to_file: bool
        If True, filtered traces are computed all at once and cached to a temporary file on disk (default False).
    cache_chunks: bool
        If True, each chunk is cached in memory in a dict (default False).
    dtype: dtype
        The dtype of the traces

    Returns
    -------
    filter_recording: BandpassFilterRecording
        The filtered recording extractor object
    '''
    if cache_to_file:
        assert not cache_chunks, 'if cache_to_file is True, cache_chunks must be False'

    bpf_recording = BandpassFilterRecording(recording=recording,
                                            freq_min=freq_min,
                                            freq_max=freq_max,
                                            freq_wid=freq_wid,
                                            filter_type=filter_type,
                                            order=order,
                                            chunk_size=chunk_size,
                                            cache_chunks=cache_chunks,
                                            dtype=dtype)
    if cache_to_file:
        return se.CacheRecordingExtractor(bpf_recording, chunk_size=chunk_size)
    else:
        return bpf_recording
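A hedged usage sketch of the wrapper above (in the released spiketoolkit it is exposed as st.preprocessing.bandpass_filter; the toy data and parameter values are illustrative):

import spikeextractors as se
import spiketoolkit as st

recording, _ = se.example_datasets.toy_example()

# Lazy band-pass: traces are filtered chunk by chunk when they are requested.
rec_bp = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000,
                                          filter_type='butter', order=3)

# cache_to_file=True materializes the filtered traces once into a temporary binary
# file through se.CacheRecordingExtractor, as in the return branch above.
rec_bp_cached = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000,
                                                 cache_to_file=True)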
Example #5
    def test_cache_extractor(self):
        cache_extractor = se.CacheRecordingExtractor(self.RX)
        self._check_recording_return_types(cache_extractor)
        self._check_recordings_equal(self.RX, cache_extractor)
        cache_extractor.save_to_file('cache')

        assert cache_extractor.get_filename() == 'cache.dat'
        del cache_extractor
        assert not Path('cache.dat').is_file()
Example #6
    def test_cache_extractor(self):
        cache_extractor = se.CacheRecordingExtractor(self.RX)
        check_recording_return_types(cache_extractor)
        check_recordings_equal(self.RX, cache_extractor)
        cache_extractor.save_to_file('cache')

        assert cache_extractor.filename == 'cache.dat'
        check_dumping(cache_extractor)

        # test saving to file
        del cache_extractor
        assert Path('cache.dat').is_file()

        # test tmp
        cache_extractor = se.CacheRecordingExtractor(self.RX)
        tmp_file = cache_extractor.filename
        del cache_extractor
        assert not Path(tmp_file).is_file()
Example #7
def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_to_file=False, cache_chunks=False):
    '''
    Performs a notch filter on the recording extractor traces using scipy iirnotch function.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to be notch-filtered.
    freq: int or float
        The target frequency of the notch filter.
    q: int
        The quality factor of the notch filter.
    chunk_size: int
        The chunk size to be used for the filtering.
    cache_to_file: bool
        If True, filtered traces are computed all at once and cached to a temporary file on disk (default False).
    cache_chunks: bool
        If True, each chunk is cached in memory in a dict (default False).

    Returns
    -------
    filter_recording: NotchFilterRecording
        The notch-filtered recording extractor object
    '''

    if cache_to_file:
        assert not cache_chunks, 'if cache_to_file is True, cache_chunks must be False'

    notch_recording = NotchFilterRecording(
        recording=recording,
        freq=freq,
        q=q,
        chunk_size=chunk_size,
        cache_chunks=cache_chunks,
    )
    if cache_to_file:
        return se.CacheRecordingExtractor(notch_recording, chunk_size=chunk_size)
    else:
        return notch_recording
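The notch wrapper follows the same pattern; a short sketch, assuming it is available as st.preprocessing.notch_filter in the installed toolkit:

import spikeextractors as se
import spiketoolkit as st

recording, _ = se.example_datasets.toy_example()

# Suppress a narrow band around 3 kHz; cache_to_file=True would again route the
# result through se.CacheRecordingExtractor, exactly as in bandpass_filter.
rec_notch = st.preprocessing.notch_filter(recording, freq=3000, q=30)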
    def make(self, key):
        key['analysis_file_name'] = AnalysisNwbfile().create(key['nwb_file_name'])
        # get the valid times.
        # NOTE: we sort each entry in the sort intervals list independently
        sort_intervals = (SortIntervalList() & {'nwb_file_name': key['nwb_file_name'],
                                                'sort_interval_list_name': key['sort_interval_list_name']}
                          ).fetch1('sort_intervals')
        interval_list_name = (SpikeSortingParameters() & key).fetch1('interval_list_name')
        valid_times = (IntervalList() & {'nwb_file_name': key['nwb_file_name'],
                                         'interval_list_name': interval_list_name}
                       ).fetch('valid_times')[0]
        raw_data_obj = (Raw() & {'nwb_file_name' : key['nwb_file_name']}).fetch_nwb()[0]['raw']
        timestamps = np.asarray(raw_data_obj.timestamps)
        sampling_rate = estimate_sampling_rate(timestamps[0:100000], 1.5)

        units = dict()
        units_valid_times = dict()
        units_sort_interval = dict()
        units_templates = dict()
        units_waveforms = dict()
        # we will add an offset to the unit_id for each sort interval to avoid duplicating ids
        unit_id_offset = 0
        # iterate through the arrays of sort intervals, sorting each interval separately
        for sort_interval in sort_intervals:
            # Get the list of valid times for this sort interval
            recording_extractor, sort_interval_valid_times = self.get_recording_extractor(key, sort_interval)
            sort_parameters = (SpikeSorterParameters() & {'sorter_name': key['sorter_name'],
                                                        'parameter_set_name': key['parameter_set_name']}).fetch1()
            # get a name for the recording extractor for this sort interval
            recording_extractor_path = os.path.join(os.environ['SPIKE_SORTING_STORAGE_DIR'],
                                                    key['analysis_file_name'] + '_' + str(sort_interval))
            recording_extractor_cached = se.CacheRecordingExtractor(recording_extractor, save_path=recording_extractor_path)
            print(f'Sorting {key}...')
            sort = si.sorters.run_mountainsort4(recording=recording_extractor_cached, 
                                                **sort_parameters['parameter_dict'], 
                                                grouping_property='group', 
                                                output_folder=os.getenv('SORTING_TEMP_DIR', None))
            # create a stack of labelled arrays of the sorted spike times
            timestamps = np.asarray(raw_data_obj.timestamps)
            unit_ids = sort.get_unit_ids()
            # get the waveforms; we may want to specify these parameters more flexibly in the future
            waveform_params = st.postprocessing.get_waveforms_params()
            print(sort_parameters)
            waveform_params['grouping_property'] = 'group'
            # set the window to half of the clip size before and half after
            waveform_params['ms_before'] = sort_parameters['parameter_dict']['clip_size'] / sampling_rate * 1000 / 2
            waveform_params['ms_after'] = waveform_params['ms_before'] 
            waveform_params['max_spikes_per_unit'] = 1000
            waveform_params['dtype'] = 'i2'
            #template_params['n_jobs'] = 7
            waveform_params['verbose'] = False
            #print(f'template_params: {template_params}')
            templates = st.postprocessing.get_unit_templates(recording_extractor_cached, sort, **waveform_params)
            # for the waveforms themselves we only need to change the max_spikes_per_unit:
            waveform_params['max_spikes_per_unit'] = 1e100
            waveforms = st.postprocessing.get_unit_waveforms(recording_extractor_cached, sort, unit_ids, **waveform_params)

            for index, unit_id in enumerate(unit_ids):
                current_index = unit_id + unit_id_offset
                unit_spike_samples = sort.get_unit_spike_train(unit_id=unit_id)  
                #print(f'template for {unit_id}: {unit_templates[unit_id]} ')
                units[current_index] = timestamps[unit_spike_samples]
                # the templates are zero based, so we have to use the index here. 
                units_templates[current_index] = templates[index]
                units_waveforms[current_index] = waveforms[index]
                units_valid_times[current_index] = sort_interval_valid_times
                units_sort_interval[current_index] = [sort_interval]
            if len(unit_ids) > 0:
                unit_id_offset += np.max(unit_ids) + 1
        
        # Add the units to the Analysis file
        # TODO: consider replacing with spikeinterface call if possible
        units_object_id, units_waveforms_object_id = AnalysisNwbfile().add_units(key['analysis_file_name'], units, units_templates, units_valid_times,
                                                              units_sort_interval, units_waveforms=units_waveforms)
        key['units_object_id'] = units_object_id
        key['units_waveforms_object_id'] = units_waveforms_object_id
        self.insert1(key)
Example #9
def process_openephys(project, action_id, probe_path, sorter, acquisition_folder=None,
                      exdir_file_path=None, spikesort=True, compute_lfp=True, compute_mua=False, parallel=False,
                      spikesorter_params=None, server=None, bad_channels=None, ref=None, split=None, sort_by=None,
                      ms_before_wf=1, ms_after_wf=2, bad_threshold=2, firing_rate_threshold=0,
                      isi_viol_threshold=0):
    import spikeextractors as se
    import spiketoolkit as st
    import spikesorters as ss
    bad_channels = bad_channels or []
    proc_start = time.time()

    if server is None or server == 'local':
        if acquisition_folder is None:
            action = project.actions[action_id]
            # if exdir_path is None:
            exdir_path = _get_data_path(action)
            exdir_file = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
            acquisition = exdir_file["acquisition"]
            if acquisition.attrs['acquisition_system'] is None:
                raise ValueError('No Open Ephys acquisition system ' +
                                 'related to this action')
            openephys_session = acquisition.attrs["session"]
            openephys_path = Path(acquisition.directory) / openephys_session
        else:
            openephys_path = Path(acquisition_folder)
            assert exdir_file_path is not None
            exdir_path = Path(exdir_file_path)

        probe_path = probe_path or project.config.get('probe')
        recording = se.OpenEphysRecordingExtractor(str(openephys_path))
        recording = recording.load_probe_file(probe_path)

        if 'auto' not in bad_channels and len(bad_channels) > 0:
            recording_active = st.preprocessing.remove_bad_channels(recording, bad_channel_ids=bad_channels)
        else:
            recording_active = recording

        # apply filtering and cmr
        print('Writing filtered and common referenced data')

        tmp_folder = Path(f"tmp_{action_id}_si")

        freq_min_hp = 300
        freq_max_hp = 3000
        freq_min_lfp = 1
        freq_max_lfp = 300
        freq_resample_lfp = 1000
        freq_resample_mua = 1000
        type_hp = 'butter'
        order_hp = 5

        recording_hp = st.preprocessing.bandpass_filter(
            recording_active, freq_min=freq_min_hp, freq_max=freq_max_hp,
            filter_type=type_hp, order=order_hp)

        if ref is not None:
            if ref.lower() == 'cmr':
                reference = 'median'
            elif ref.lower() == 'car':
                reference = 'average'
            else:
                raise Exception("'reference' can be either 'cmr' or 'car'")
            if split == 'all':
                recording_cmr = st.preprocessing.common_reference(recording_hp, reference=reference)
            elif split == 'half':
                groups = [recording.get_channel_ids()[:int(len(recording.get_channel_ids()) / 2)],
                          recording.get_channel_ids()[int(len(recording.get_channel_ids()) / 2):]]
                recording_cmr = st.preprocessing.common_reference(recording_hp, groups=groups, reference=reference)
            else:
                if isinstance(split, list):
                    recording_cmr = st.preprocessing.common_reference(recording_hp, groups=split, reference=reference)
                else:
                    raise Exception("'split' must be a list of lists")
        else:
            recording_cmr = recording_hp

        if 'auto' in bad_channels:
            recording_cmr = st.preprocessing.remove_bad_channels(recording_cmr, bad_channel_ids=None,
                                                                 bad_threshold=bad_threshold, seconds=10)
            recording_active = se.SubRecordingExtractor(
                recording, channel_ids=recording_cmr.active_channels)

        print("Active channels: ", len(recording_active.get_channel_ids()))
        recording_lfp = st.preprocessing.bandpass_filter(
            recording_active, freq_min=freq_min_lfp, freq_max=freq_max_lfp)
        recording_lfp = st.preprocessing.resample(
            recording_lfp, freq_resample_lfp)
        recording_mua = st.preprocessing.resample(
            st.preprocessing.rectify(recording_active), freq_resample_mua)

        if spikesort:
            print('Bandpass filter')
            t_start = time.time()
            recording_cmr = se.CacheRecordingExtractor(recording_cmr, save_path=tmp_folder / 'filt.dat')
            print('Filter time: ', time.time() - t_start)
        if compute_lfp:
            print('Computing LFP')
            t_start = time.time()
            recording_lfp = se.CacheRecordingExtractor(recording_lfp, save_path=tmp_folder / 'lfp.dat')
            print('Filter time: ', time.time() - t_start)

        if compute_mua:
            print('Computing MUA')
            t_start = time.time()
            recording_mua = se.CacheRecordingExtractor(recording_mua, save_path=tmp_folder / 'mua.dat')
            print('Filter time: ', time.time() - t_start)

        print('Number of channels', recording_cmr.get_num_channels())

        if spikesort:
            try:
                # save attributes
                exdir_group = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
                ephys = exdir_group.require_group('processing').require_group('electrophysiology')
                spikesorting = ephys.require_group('spikesorting')
                sorting_group = spikesorting.require_group(sorter)
                output_folder = sorting_group.require_raw('output').directory
                if 'kilosort' in sorter:
                    sorting = ss.run_sorter(sorter, recording_cmr,
                                            parallel=parallel, verbose=True,
                                            delete_output_folder=True, **spikesorter_params)
                else:
                    sorting = ss.run_sorter(
                        sorter, recording_cmr, parallel=parallel,
                        grouping_property=sort_by, verbose=True, output_folder=output_folder,
                        delete_output_folder=True, **spikesorter_params)
                spike_sorting_attrs = {'name': sorter, 'params': spikesorter_params}
                filter_attrs = {'hp_filter': {'low': freq_min_hp, 'high': freq_max_hp},
                                'lfp_filter': {'low': freq_min_lfp, 'high': freq_max_lfp,
                                               'resample': freq_resample_lfp},
                                'mua_filter': {'resample': freq_resample_mua}}
                reference_attrs = {'type': str(ref), 'split': str(split)}
                sorting_group.attrs.update({'spike_sorting': spike_sorting_attrs,
                                            'filter': filter_attrs,
                                            'reference': reference_attrs})
            except Exception as e:
                try:
                    shutil.rmtree(tmp_folder)
                except:
                    print(f'Could not remove tmp processing folder: {tmp_folder}')
                raise Exception("Spike sorting failed") from e
            print('Found ', len(sorting.get_unit_ids()), ' units!')

        # extract waveforms
        if spikesort:
            # se.ExdirSortingExtractor.write_sorting(
            #     sorting, exdir_path, recording=recording_cmr, verbose=True)
            print('Saving Phy output')
            phy_folder = sorting_group.require_raw('phy').directory
            if firing_rate_threshold > 0:
                sorting_min = st.curation.threshold_firing_rates(sorting,
                                                                 threshold=firing_rate_threshold,
                                                                 threshold_sign='less',
                                                                 duration_in_frames=recording_cmr.get_num_frames())
                print("Removed ", (len(sorting.get_unit_ids()) - len(sorting_min.get_unit_ids())),
                      'units with firing rate less than', firing_rate_threshold)
            else:
                sorting_min = sorting
            if isi_viol_threshold > 0:
                sorting_viol = st.curation.threshold_isi_violations(sorting_min,
                                                                    threshold=isi_viol_threshold,
                                                                    threshold_sign='greater',
                                                                    duration_in_frames=recording_cmr.get_num_frames())
                print("Removed ", (len(sorting_min.get_unit_ids()) - len(sorting_viol.get_unit_ids())),
                      'units with ISI violation greater than', isi_viol_threshold)
            else:
                sorting_viol = sorting_min
            t_start_save = time.time()
            sorting_viol.set_tmp_folder(tmp_folder)
            st.postprocessing.export_to_phy(recording_cmr, sorting_viol, output_folder=phy_folder,
                                            ms_before=ms_before_wf, ms_after=ms_after_wf, verbose=True,
                                            grouping_property=sort_by, recompute_info=True,
                                            save_as_property_or_feature=True)
            print('Save to phy time:', time.time() - t_start_save)
        if compute_lfp:
            print('Saving LFP to exdir format')
            se.ExdirRecordingExtractor.write_recording(
                recording_lfp, exdir_path, lfp=True)
        if compute_mua:
            print('Saving MUA to exdir format')
            se.ExdirRecordingExtractor.write_recording(
                recording_mua, exdir_path, mua=True)

        # save attributes
        exdir_group = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        spike_sorting_attrs = {'name': sorter, 'params': spikesorter_params}
        filter_attrs = {'hp_filter': {'low': freq_min_hp, 'high': freq_max_hp},
                        'lfp_filter': {'low': freq_min_lfp, 'high': freq_max_lfp, 'resample': freq_resample_lfp},
                        'mua_filter': {'resample': freq_resample_mua}}
        reference_attrs = {'type': str(ref), 'split': str(split)}
        ephys.attrs.update({'spike_sorting': spike_sorting_attrs,
                            'filter': filter_attrs,
                            'reference': reference_attrs})

        try:
            if spikesort:
                del recording_cmr
            if compute_lfp:
                del recording_lfp
            if compute_mua:
                del recording_mua
            shutil.rmtree(tmp_folder)
        except:
            print(f'Could not remove tmp processing folder: {tmp_folder}')
    else:
        config = expipe.config._load_config_by_name(None)
        assert server in [s['host'] for s in config.get('servers')]
        server_dict = [s for s in config.get('servers') if s['host'] == server][0]
        host = server_dict['domain']
        user = server_dict['user']
        password = server_dict['password']
        port = 22

        # host, user, pas, port = utils.get_login(
        #     hostname=hostname, username=username, port=port, password=password)
        ssh, scp_client, sftp_client, pbar = utils.login(
            hostname=host, username=user, password=password, port=port)
        print('Invoking remote shell')
        remote_shell = utils.ShellHandler(ssh)

        ########################## SEND  #######################################
        action = project.actions[action_id]
        # if exdir_path is None:
        exdir_path = _get_data_path(action)
        exdir_path_str = str(exdir_path)
        exdir_file = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
        acquisition = exdir_file["acquisition"]
        if acquisition.attrs['acquisition_system'] is None:
            raise ValueError('No Open Ephys acquisition system ' +
                             'related to this action')
        openephys_session = acquisition.attrs["session"]
        openephys_path = Path(acquisition.directory) / openephys_session
        print('Initializing transfer of "' + str(openephys_path) + '" to "' +
              host + '"')

        try:  # make directory for untarring
            process_folder = '/tmp/process_' + str(np.random.randint(10000000))
            stdin, stdout, stderr = remote_shell.execute('mkdir ' + process_folder)
        except IOError:
            pass
        print('Packing tar archive')
        remote_acq = process_folder + '/acquisition'
        remote_tar = process_folder + '/acquisition.tar'

        # transfer acquisition folder
        local_tar = shutil.make_archive(str(openephys_path), 'tar', str(openephys_path))
        print(local_tar)
        scp_client.put(
            local_tar, remote_tar, recursive=False)

        # transfer probe_file
        remote_probe = process_folder + '/probe.prb'
        scp_client.put(
            probe_path, remote_probe, recursive=False)

        remote_exdir = process_folder + '/main.exdir'
        remote_proc = process_folder + '/main.exdir/processing'
        remote_proc_tar = process_folder + '/processing.tar'
        local_proc = str(exdir_path / 'processing')
        local_proc_tar = local_proc + '.tar'

        # transfer spike params
        if spikesorter_params is not None:
            spike_params_file = 'spike_params.yaml'
            with open(spike_params_file, 'w') as f:
                yaml.dump(spikesorter_params, f)
            remote_yaml = process_folder + '/' + spike_params_file
            scp_client.put(
                spike_params_file, remote_yaml, recursive=False)
            try:
                os.remove(spike_params_file)
            except:
                print('Could not remove: ', spike_params_file)
        else:
            remote_yaml = 'none'

        extra_args = ""
        if not compute_lfp:
            extra_args = extra_args + ' --no-lfp'
        if not compute_mua:
            extra_args = extra_args + ' --no-mua'
        if not spikesort:
            extra_args = extra_args + ' --no-sorting'
        extra_args = extra_args + ' -bt {}'.format(bad_threshold)

        if ref is not None and isinstance(ref, str):
            ref = ref.lower()
        if split is not None and isinstance(split, str):
            split = split.lower()

        bad_channels_cmd = ''
        for bc in bad_channels:
            bad_channels_cmd = bad_channels_cmd + ' -bc ' + str(bc)

        ref_cmd = ''
        if ref is not None:
            ref_cmd = ' --ref ' + ref.lower()

        split_cmd = ''
        if split is not None:
            split_cmd = ' --split-channels ' + str(split)

        par_cmd = ''
        if not parallel:
            par_cmd = ' --no-par '

        sortby_cmd = ''
        if sort_by is not None:
            sortby_cmd = ' --sort-by ' + sort_by

        wf_cmd = ' --ms-before-wf ' + str(ms_before_wf) + ' --ms-after-wf ' + str(ms_after_wf)

        ms_cmd = ' --min-fr ' + str(firing_rate_threshold)

        isi_cmd = ' --min-isi ' + str(isi_viol_threshold)

        try:
            pbar[0].close()
        except Exception:
            pass

        print('Making acquisition folder')
        cmd = "mkdir " + remote_acq
        print('Shell: ', cmd)
        stdin, stdout, stderr = remote_shell.execute("mkdir " + remote_acq)
        # utils.ssh_execute(ssh, "mkdir " + remote_acq)

        print('Unpacking tar archive')
        cmd = "tar -xf " + remote_tar + " --directory " + remote_acq
        stdin, stdout, stderr = remote_shell.execute(cmd)
        # utils.ssh_execute(ssh, cmd)

        print('Deleting tar archives')
        if not os.access(str(local_tar), os.W_OK):
            # Is the error an access error ?
            os.chmod(str(local_tar), stat.S_IWUSR)
        try:
            os.remove(str(local_tar))
        except:
            print('Could not remove: ', local_tar)

        ###################### PROCESS #######################################
        print('Processing on server')
        cmd = "expipe process openephys {} --probe-path {} --sorter {} --spike-params {}  " \
              "--acquisition {} --exdir-path {} {} {} {} {} {} {} {} {} {}".format(
            action_id, remote_probe, sorter, remote_yaml, remote_acq,
            remote_exdir, bad_channels_cmd, ref_cmd, par_cmd, sortby_cmd,
            split_cmd, wf_cmd, extra_args, ms_cmd, isi_cmd)

        stdin, stdout, stderr = remote_shell.execute(cmd, print_lines=True)

        print('Finished remote processing')
        ####################### RETURN PROCESSED DATA #######################
        print('Initializing transfer of "' + remote_proc + '" to "' +
              local_proc + '"')
        print('Packing tar archive')
        cmd = "tar -C " + remote_exdir + " -cf " + remote_proc_tar + ' processing'
        stdin, stdout, stderr = remote_shell.execute(cmd, print_lines=True)
        # cmd = "ls -l " + remote_proc_tar
        # stdin, stdout, stderr = remote_shell.execute(cmd, print_lines=False)
        # wait for 5 seconds to ensure that packing is done
        time.sleep(5)
        # utils.ssh_execute(ssh, "tar -C " + remote_exdir + " -cf " + remote_proc_tar + ' processing')
        scp_client.get(remote_proc_tar, local_proc_tar, recursive=False)
        try:
            pbar[0].close()
        except Exception:
            pass

        print('Unpacking tar archive')
        if 'processing' in exdir_file:
            if 'electrophysiology' in exdir_file['processing']:
                print('Merging with old processing/electrophysiology')
        with tarfile.open(str(local_proc_tar)) as tar:
            _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if 'tracking' not in m.name and
                 'exdir.yaml' in m.name]
            if spikesort:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if
                     'spikesorting' in m.name and sorter in m.name]
            if compute_lfp:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if
                     'LFP' in m.name]
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if
                     'group' in m.name and 'attributes' in m.name]
            if compute_mua:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if
                     'MUA' in m.name]
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers() if
                     'group' in m.name and 'attributes' in m.name]

        print('Deleting tar archives')
        if not os.access(str(local_proc_tar), os.W_OK):
            # Is the error an access error ?
            os.chmod(str(local_proc_tar), stat.S_IWUSR)
        try:
            os.remove(str(local_proc_tar))
        except:
            print('Could not remove: ', local_proc_tar)
        # sftp_client.remove(remote_proc_tar)
        print('Deleting remote process folder')
        cmd = "rm -rf " + process_folder
        stdin, stdout, stderr = remote_shell.execute(cmd)

        #################### CLOSE UP #############################
        ssh.close()
        sftp_client.close()
        scp_client.close()

    # check for tracking and events (always locally)
    oe_recording = pyopenephys.File(str(openephys_path)).experiments[0].recordings[0]
    if len(oe_recording.tracking) > 0:
        print('Saving ', len(oe_recording.tracking), ' Open Ephys tracking sources')
        generate_tracking(exdir_path, oe_recording)

    if len(oe_recording.events) > 0:
        print('Saving ', len(oe_recording.events), ' Open Ephys event sources')
        generate_events(exdir_path, oe_recording)

    print('Saved to exdir: ', exdir_path)
    print("Total elapsed time: ", time.time() - proc_start)
    print('Finished processing')
def test_waveforms():
    n_wf_samples = 100
    n_jobs = [0, 2]
    for n in n_jobs:
        for m in memmaps:
            print('N jobs', n, 'memmap', m)
            folder = 'test'
            if os.path.isdir(folder):
                shutil.rmtree(folder)
            rec, sort, waveforms, templates, max_chans, amps = create_signal_with_known_waveforms(
                n_waveforms=2, n_channels=4, n_wf_samples=n_wf_samples)
            rec, sort = create_dumpable_extractors_from_existing(
                folder, rec, sort)
            # get num samples in ms
            ms_cut = n_wf_samples // 2 / rec.get_sampling_frequency() * 1000

            # no group
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     save_property_or_features=False,
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)

            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt)
            assert 'waveforms' not in sort.get_shared_unit_spike_feature_names()

            # small chunks
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     save_property_or_features=False,
                                     n_jobs=n,
                                     memmap=m,
                                     chunk_mb=5,
                                     recompute_info=True)

            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt)
            assert 'waveforms' not in sort.get_shared_unit_spike_feature_names()

            # return_scaled
            gain = 0.1
            rec_sc, sort_sc = se.example_datasets.toy_example()
            rec_sc.set_channel_gains(gain)
            rec_sc.has_unscaled = True
            rec_cache = se.CacheRecordingExtractor(rec_sc, return_scaled=False)
            wav_unscaled = get_unit_waveforms(rec_cache,
                                              sort_sc,
                                              ms_before=ms_cut,
                                              ms_after=ms_cut,
                                              save_property_or_features=False,
                                              n_jobs=n,
                                              memmap=m,
                                              return_scaled=False,
                                              recompute_info=True)
            wav_unscaled = [np.array(wf) for wf in wav_unscaled]
            wav_scaled = get_unit_waveforms(rec_cache,
                                            sort_sc,
                                            ms_before=ms_cut,
                                            ms_after=ms_cut,
                                            save_property_or_features=False,
                                            n_jobs=n,
                                            memmap=m,
                                            return_scaled=True,
                                            recompute_info=True)
            wav_scaled = [np.array(wf) for wf in wav_scaled]

            for (w_unscaled, w_scaled) in zip(wav_unscaled, wav_scaled):
                assert np.allclose(w_unscaled * gain, w_scaled)

            # change cut ms
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=2,
                                     ms_after=2,
                                     save_property_or_features=True,
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)

            for (w, w_gt) in zip(wav, waveforms):
                _, _, samples = w.shape
                assert np.allclose(
                    w[:, :, samples // 2 - n_wf_samples // 2:samples // 2 +
                      n_wf_samples // 2], w_gt)
            assert 'waveforms' in sort.get_shared_unit_spike_feature_names()

            # by group
            rec.set_channel_groups([0, 0, 1, 1])
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     grouping_property='group',
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)

            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :2]) or np.allclose(
                    w, w_gt[:, 2:])

            # test compute_property_from_recordings
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     grouping_property='group',
                                     compute_property_from_recording=True,
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :2]) or np.allclose(
                    w, w_gt[:, 2:])

            # test max_spikes_per_unit
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     max_spikes_per_unit=10,
                                     save_property_or_features=False,
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)
            for w in wav:
                assert len(w) <= 10

            # test channels
            wav = get_unit_waveforms(rec,
                                     sort,
                                     ms_before=ms_cut,
                                     ms_after=ms_cut,
                                     channel_ids=[0, 1, 2],
                                     n_jobs=n,
                                     memmap=m,
                                     recompute_info=True)

            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :3])
    shutil.rmtree('test')