Exemplo n.º 1
0
def _mountainsort4(
    recording,  # The recording extractor
    output_folder=None,
    detect_sign=-1,  # Use -1, 0, or 1, depending on the sign of the spikes in the recording
    adjacency_radius=-1,  # Use -1 to include all channels in every neighborhood
    freq_min=300,  # Use None for no bandpass filtering
    freq_max=6000,
    whiten=True,  # Whether to do channel whitening as part of preprocessing
    clip_size=50,
    detect_threshold=3,
    detect_interval=10,  # Minimum number of timepoints between events detected on the same channel
    noise_overlap_threshold=0.15  # Use None for no automated curation
):
    """Preprocess ``recording`` and spike-sort it with MountainSort4.

    Args:
        recording: The recording extractor to sort.
        output_folder: Accepted for interface compatibility; this function
            does not use it.
        detect_sign: Spike sign passed to ``ml_ms4alg.mountainsort4``.
        adjacency_radius: Neighborhood radius passed to the sorter.
        freq_min: Lower bandpass edge; ``None`` skips bandpass filtering.
        freq_max: Upper bandpass edge (only used when ``freq_min`` is set).
        whiten: Whether to whiten the channels before sorting.
        clip_size: Clip size passed to the sorter.
        detect_threshold: Detection threshold passed to the sorter.
        detect_interval: Minimum timepoints between events on one channel.
        noise_overlap_threshold: Curation threshold; ``None`` skips curation.

    Returns:
        The sorting produced by ``ml_ms4alg.mountainsort4``, passed through
        ``ml_ms4alg.mountainsort4_curation`` when curation is enabled.

    Raises:
        ModuleNotFoundError: If ``ml_ms4alg`` is not installed. The original
            import error is chained as ``__cause__``.
    """
    try:
        import ml_ms4alg
    except ModuleNotFoundError as err:
        # Chain the original error so the traceback shows it as the cause
        # rather than as an unrelated failure inside this handler.
        raise ModuleNotFoundError(
            "\nTo use Mountainsort, install ml_ms4alg: \n\n"
            "\npip install ml_ms4alg\n"
            "\nMore information on Mountainsort at: "
            "\nhttps://github.com/flatironinstitute/mountainsort") from err

    # Bandpass filter
    if freq_min is not None:
        recording = st.preprocessing.bandpass_filter(recording=recording,
                                                     freq_min=freq_min,
                                                     freq_max=freq_max)

    # Whiten
    if whiten:
        recording = st.preprocessing.whiten(recording=recording)

    # When no channel carries a 'location' property, assign placeholder
    # locations [0, i] (one per channel, in channel-id order).
    if 'location' not in recording.getChannelPropertyNames():
        for i, chan in enumerate(recording.getChannelIds()):
            recording.setChannelProperty(chan, 'location', [0, i])

    # Sort
    sorting = ml_ms4alg.mountainsort4(recording=recording,
                                      detect_sign=detect_sign,
                                      adjacency_radius=adjacency_radius,
                                      clip_size=clip_size,
                                      detect_threshold=detect_threshold,
                                      detect_interval=detect_interval)

    # Curate
    if noise_overlap_threshold is not None:
        sorting = ml_ms4alg.mountainsort4_curation(
            recording=recording,
            sorting=sorting,
            noise_overlap_threshold=noise_overlap_threshold)

    return sorting
Exemplo n.º 2
0
def mountainsort4(
        recording,  # The recording extractor
        detect_sign,  # Use -1, 0, or 1, depending on the sign of the spikes in the recording
        adjacency_radius,  # Use -1 to include all channels in every neighborhood
        freq_min=300,  # Use None for no bandpass filtering
        freq_max=6000,
        whiten=True,  # Whether to do channel whitening as part of preprocessing
        clip_size=50,
        detect_threshold=3,
        detect_interval=10,  # Minimum number of timepoints between events detected on the same channel
        noise_overlap_threshold=0.15,  # Use None for no automated curation
        num_workers=None,  # Use None for multiprocessing.cpu_count/2
):
    """Preprocess ``recording`` with lazy filters and sort it with MountainSort4.

    Args:
        recording: The recording extractor to sort.
        detect_sign: Spike sign passed to ``ml_ms4alg.mountainsort4``.
        adjacency_radius: Neighborhood radius passed to the sorter.
        freq_min: Lower bandpass edge; ``None`` skips bandpass filtering.
        freq_max: Upper bandpass edge (only used when ``freq_min`` is set).
        whiten: Whether to whiten the channels before sorting.
        clip_size: Clip size passed to the sorter.
        detect_threshold: Detection threshold passed to the sorter.
        detect_interval: Minimum timepoints between events on one channel.
        noise_overlap_threshold: Curation threshold; ``None`` skips curation.
        num_workers: Worker count forwarded to ``ml_ms4alg.mountainsort4``.

    Returns:
        The (optionally curated) sorting from ``ml_ms4alg``. The elapsed
        processing time is printed before returning.

    Raises:
        ModuleNotFoundError: If ``ml_ms4alg`` is not installed. The original
            import error is chained as ``__cause__``.
    """
    try:
        import ml_ms4alg
    except ModuleNotFoundError as err:
        # Chain the original error so it is reported as the direct cause.
        raise ModuleNotFoundError(
            "\nTo use Mountainsort, install ml_ms4alg: \n\n"
            "\npip install ml_ms4alg\n"
            "\nMore information on Mountainsort at: "
            "\nhttps://github.com/flatironinstitute/mountainsort") from err

    # perf_counter is monotonic, so the measured interval is not affected by
    # system clock adjustments (unlike time.time()).
    t_start_proc = time.perf_counter()

    # Bandpass filter
    if freq_min is not None:
        recording = sw.lazyfilters.bandpass_filter(recording=recording,
                                                   freq_min=freq_min,
                                                   freq_max=freq_max)

    # Whiten
    if whiten:
        recording = sw.lazyfilters.whiten(recording=recording)

    # Sort
    sorting = ml_ms4alg.mountainsort4(recording=recording,
                                      detect_sign=detect_sign,
                                      adjacency_radius=adjacency_radius,
                                      clip_size=clip_size,
                                      detect_threshold=detect_threshold,
                                      detect_interval=detect_interval,
                                      num_workers=num_workers)

    # Curate
    if noise_overlap_threshold is not None:
        sorting = ml_ms4alg.mountainsort4_curation(
            recording=recording,
            sorting=sorting,
            noise_overlap_threshold=noise_overlap_threshold)
    print('Elapsed time: ', time.perf_counter() - t_start_proc)

    return sorting
Exemplo n.º 3
0
    def _run(self, recording, output_folder):
        """Sort ``recording`` with MountainSort4 and save results to ``output_folder``.

        The recording is first passed through ``recover_recording``. It is
        then optionally bandpass filtered and whitened according to
        ``self.params``, sorted, and optionally curated. ``firings.mda`` and
        ``samplerate.txt`` are written into ``output_folder``.
        """
        recording = recover_recording(recording)
        params = self.params

        # Warn about double filtering when the recording reports it is
        # already filtered but this sorter's own filter step is also enabled.
        if recording.is_filtered and params['filter']:
            print(
                "Warning! The recording is already filtered, but Mountainsort4 filter is enabled. You can disable "
                "filters by setting 'filter' parameter to False")

        # Read the sampling rate before any preprocessing wrappers are applied.
        samplerate = recording.get_sampling_frequency()

        band_is_set = params['filter'] and all(
            params[edge] is not None for edge in ('freq_min', 'freq_max'))
        if band_is_set:
            recording = bandpass_filter(recording=recording,
                                        freq_min=params['freq_min'],
                                        freq_max=params['freq_max'])

        if params['whiten']:
            recording = whiten(recording=recording)

        # Channel locations are handled by the base sorter, not here.

        sorter_kwargs = {
            name: params[name]
            for name in ('detect_sign', 'adjacency_radius', 'clip_size',
                         'detect_threshold', 'detect_interval', 'num_workers')
        }
        sorting = ml_ms4alg.mountainsort4(recording=recording,
                                          verbose=self.verbose,
                                          **sorter_kwargs)

        curate = (params['noise_overlap_threshold'] is not None
                  and params['curation'] is True)
        if curate:
            if self.verbose:
                print('Curating')
            sorting = ml_ms4alg.mountainsort4_curation(
                recording=recording,
                sorting=sorting,
                noise_overlap_threshold=params['noise_overlap_threshold'])

        firings_path = str(output_folder / 'firings.mda')
        se.MdaSortingExtractor.write_sorting(sorting, firings_path)

        rate_path = str(output_folder / 'samplerate.txt')
        with open(rate_path, 'w') as rate_file:
            rate_file.write('{}'.format(samplerate))
Exemplo n.º 4
0
    def _run(self, recording, output_folder):
        """Sort ``recording`` with MountainSort4 and save results to ``output_folder``.

        Optional bandpass filtering and whitening are controlled by
        ``self.params``. A placeholder 'location' property is added when
        missing. Results are written as ``firings.mda`` plus
        ``samplerate.txt``.
        """
        params = self.params

        # Read the sampling rate before any preprocessing wrappers are applied.
        samplerate = recording.get_sampling_frequency()

        band_is_set = params['filter'] and all(
            params[edge] is not None for edge in ('freq_min', 'freq_max'))
        if band_is_set:
            recording = bandpass_filter(recording=recording,
                                        freq_min=params['freq_min'],
                                        freq_max=params['freq_max'])
        if params['whiten']:
            recording = whiten(recording=recording)

        # If no shared 'location' property exists, lay channels out as
        # [0, index] in channel-id order.
        if 'location' not in recording.get_shared_channel_property_names():
            for index, channel_id in enumerate(recording.get_channel_ids()):
                recording.set_channel_property(channel_id, 'location', [0, index])

        sorter_kwargs = {
            name: params[name]
            for name in ('detect_sign', 'adjacency_radius', 'clip_size',
                         'detect_threshold', 'detect_interval', 'num_workers')
        }
        sorting = ml_ms4alg.mountainsort4(recording=recording,
                                          verbose=self.verbose,
                                          **sorter_kwargs)

        curate = (params['noise_overlap_threshold'] is not None
                  and params['curation'] is True)
        if curate:
            if self.verbose:
                print('Curating')
            sorting = ml_ms4alg.mountainsort4_curation(
                recording=recording,
                sorting=sorting,
                noise_overlap_threshold=params['noise_overlap_threshold'])

        firings_path = str(output_folder / 'firings.mda')
        se.MdaSortingExtractor.write_sorting(sorting, firings_path)

        rate_path = str(output_folder / 'samplerate.txt')
        with open(rate_path, 'w') as rate_file:
            rate_file.write('{}'.format(samplerate))
Exemplo n.º 5
0
    def run(self):
        """Sort the MDA recording in ``self.recording_dir`` and write ``self.firings_out``.

        When ``self.throw_error`` is set, the method waits 3 seconds and
        raises an intentional error instead. The worker count comes from the
        ``NUM_WORKERS`` environment variable when it is set and non-empty.
        """
        if self.throw_error:
            import time
            print(
                'Intentionally throwing an error in 3 seconds (MountainSort4TestError)...'
            )
            sys.stdout.flush()
            time.sleep(3)
            raise Exception('Intentional error.')

        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)

        env_workers = os.environ.get('NUM_WORKERS', None)
        num_workers = int(env_workers) if env_workers else env_workers

        # Filter only when at least one band edge is truthy.
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(recording=recording,
                                        freq_min=self.freq_min,
                                        freq_max=self.freq_max)

        if self.whiten:
            recording = whiten(recording=recording)

        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers)

        # NOTE: automated curation (ml_ms4alg.mountainsort4_curation with
        # self.noise_overlap_threshold) is intentionally not applied here.

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Exemplo n.º 6
0
    def _run(self, recording, output_folder):
        """Sort ``recording`` with MountainSort4 and write ``firings.mda`` to ``output_folder``.

        The parameters in ``self.params`` are printed first. Optional
        bandpass filtering, whitening and curation are controlled by those
        params.
        """
        params = self.params
        print(params)

        # The index itself is unused. The lookup is kept because
        # list.index raises ValueError when ``recording`` is not in
        # ``self.recording_list``.
        self.recording_list.index(recording)

        band_is_set = params['filter'] and all(
            params[edge] is not None for edge in ('freq_min', 'freq_max'))
        if band_is_set:
            recording = bandpass_filter(recording=recording,
                                        freq_min=params['freq_min'],
                                        freq_max=params['freq_max'])

        if params['whiten']:
            recording = whiten(recording=recording)

        # If no 'location' property exists, lay channels out as [0, index].
        if 'location' not in recording.get_channel_property_names():
            for index, channel_id in enumerate(recording.get_channel_ids()):
                recording.set_channel_property(channel_id, 'location', [0, index])

        sorter_kwargs = {
            name: params[name]
            for name in ('detect_sign', 'adjacency_radius', 'clip_size',
                         'detect_threshold', 'detect_interval')
        }
        sorting = ml_ms4alg.mountainsort4(recording=recording, **sorter_kwargs)

        curate = (params['noise_overlap_threshold'] is not None
                  and params['curation'] is True)
        if curate:
            print('Curating')
            sorting = ml_ms4alg.mountainsort4_curation(
                recording=recording,
                sorting=sorting,
                noise_overlap_threshold=params['noise_overlap_threshold'])

        firings_path = str(output_folder / 'firings.mda')
        se.MdaSortingExtractor.write_sorting(sorting, firings_path)
Exemplo n.º 7
0
    def run(self):
        """Sort the MDA recording in ``self.recording_dir`` and write ``self.firings_out``.

        Preprocessing uses this package's ``bandpass_filter`` and ``whiten``
        helpers. The worker count comes from ``NUM_WORKERS`` when that
        variable is set and non-empty.
        """
        from .bandpass_filter import bandpass_filter
        from .whiten import whiten

        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)

        env_workers = os.environ.get('NUM_WORKERS', None)
        num_workers = int(env_workers) if env_workers else env_workers

        # Filter only when at least one band edge is truthy.
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(recording=recording,
                                        freq_min=self.freq_min,
                                        freq_max=self.freq_max)

        if self.whiten:
            recording = whiten(recording=recording)

        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
        )

        # NOTE: automated curation (ml_ms4alg.mountainsort4_curation with
        # self.noise_overlap_threshold) is intentionally not applied here.

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Exemplo n.º 8
0
    def run(self):
        """Run the full MountainSort4 pipeline on ``self.recording_file_in``.

        All intermediate files live in a temporary directory:

        1. Bandpass filter the raw recording.
        2. Optionally mask out artifacts (``self.mask_out_artifacts``).
        3. Optionally whiten (``self.whiten``).
        4. Load the preprocessed data, attach a geometry, sort it with
           ``ml_ms4alg.mountainsort4`` and write ``self.firings_out``.
        5. Compute cluster and isolation metrics, combine them and copy the
           result to ``self.metrics_out``.
        6. Build a label map from the metrics and apply it to produce
           ``self.firings_curated_out``.

        Raises:
            Exception: If ``self.geom_in`` is not a path string and the data
                has more than six channels.
        """
        import os

        # The temporary directory is removed even if a Python exception occurs.
        with TemporaryDirectory() as tmpdir:
            # Paths of the temporary/intermediate data.
            filt = os.path.join(tmpdir, 'filt.mda')
            filt2 = os.path.join(tmpdir, 'filt2.mda')
            pre = os.path.join(tmpdir, 'pre.mda')

            print('Bandpass filtering raw -> filt...')
            _bandpass_filter(self.recording_file_in, filt)

            if self.mask_out_artifacts:
                print('Masking out artifacts filt -> filt2...')
                _mask_out_artifacts(filt, filt2)
            else:
                # No file is actually copied: filt2 simply aliases filt.
                print('Copying filt -> filt2...')
                filt2 = filt

            if self.whiten:
                print('Whitening filt2 -> pre...')
                _whiten(filt2, pre)
            else:
                pre = filt2

            # Read the preprocessed timeseries into RAM (may change later).
            X = sf.mdaio.readmda(pre)
            num_channels = X.shape[0]

            # Resolve the electrode geometry.
            if isinstance(self.geom_in, str):
                print('Using geom.csv from a file', self.geom_in)
                geom = read_geom_csv(self.geom_in)
            else:
                # No geom file was provided; only small channel counts are
                # allowed, and they get an all-zero (trivial) geometry.
                if num_channels > 6:
                    raise Exception(
                        'For more than six channels, we require that a geom.csv be provided')
                print('Making a trivial geom file.')
                geom = np.zeros((num_channels, 2))

            # Represent the preprocessed recording as a RecordingExtractor.
            # NOTE(review): the sampling rate is hard-coded to 30000 Hz —
            # confirm it matches the input recording.
            recording = se.NumpyRecordingExtractor(
                X, samplerate=30000, geom=geom)

            # Hard-coded for now. Idea: run many simultaneous jobs, each using
            # only 2 cores. The .sh script that calls this should set the
            # relevant environment variables.
            num_workers = 2

            sorting = ml_ms4alg.mountainsort4(
                recording=recording,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                clip_size=self.clip_size,
                detect_threshold=self.detect_threshold,
                detect_interval=self.detect_interval,
                num_workers=num_workers,
            )

            print('Writing firings.mda...')
            sf.SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)

            print('Computing cluster metrics...')
            cluster_metrics_path = os.path.join(tmpdir, 'cluster_metrics.json')
            _cluster_metrics(pre, self.firings_out, cluster_metrics_path)

            print('Computing isolation metrics...')
            isolation_metrics_path = os.path.join(tmpdir, 'isolation_metrics.json')
            pair_metrics_path = os.path.join(tmpdir, 'pair_metrics.json')
            _isolation_metrics(pre, self.firings_out,
                               isolation_metrics_path, pair_metrics_path)

            print('Combining metrics...')
            metrics_path = os.path.join(tmpdir, 'metrics.json')
            _combine_metrics(cluster_metrics_path,
                             isolation_metrics_path, metrics_path)

            # Copy metrics.json out before the temporary directory is removed.
            shutil.copy(metrics_path, self.metrics_out)

            print('Creating label map...')
            label_map_path = os.path.join(tmpdir, 'label_map.mda')
            create_label_map(metrics=metrics_path,
                             label_map_out=label_map_path)

            print('Applying label map...')
            apply_label_map(firings=self.firings_out, label_map=label_map_path,
                            firings_out=self.firings_curated_out)