コード例 #1
0
class KiloSort(mlpr.Processor):
    """Run KiloSort on an MDA recording directory and write an MDA firings file.

    Sorting itself is delegated to ``kilosort_helper``; this wrapper loads the
    recording (optionally restricted to ``channels``), provides a scratch
    directory under ``$TEMPDIR`` (default ``/tmp``), and writes the result to
    ``firings_out``.
    """
    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording, write ``firings_out``, and always remove the scratch dir.

        Raises:
            Whatever the extractor, ``kilosort_helper`` or the writer raise; the
            scratch directory is removed before the exception propagates.
        """
        # Random suffix so concurrent runs do not share a scratch directory.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan
            )
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for both success and failure (replaces the
            # former bare ``except:`` + duplicated rmtree).
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #2
0
class MountainSort4(mlpr.Processor):
    """Wrap ``sf.sorters.mountainsort4``: sort an MDA dataset directory into a firings file.

    The worker count is taken from the ``NUM_WORKERS`` environment variable;
    a missing or non-positive value is passed on as ``None``.
    """
    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    dataset_dir = mlpr.Input('Directory of dataset', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True,
        default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True,
                                      default=50,
                                      description='')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=3,
                                           description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True,
        default=10,
        description=
        'Minimum number of timepoints between events detected on the same channel'
    )
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True,
        default=0.15,
        description='Use None for no automated curation')

    def run(self):
        """Load the dataset, run MountainSort4 and write ``firings_out``."""
        rx = si.MdaRecordingExtractor(self.dataset_dir)

        requested = int(os.environ.get('NUM_WORKERS', -1))
        # Non-positive (including the -1 default) means "no explicit worker count".
        workers = requested if requested > 0 else None

        sorter_options = {
            'detect_sign': self.detect_sign,
            'adjacency_radius': self.adjacency_radius,
            'freq_min': self.freq_min,
            'freq_max': self.freq_max,
            'whiten': self.whiten,
            'clip_size': self.clip_size,
            'detect_threshold': self.detect_threshold,
            'detect_interval': self.detect_interval,
            'noise_overlap_threshold': self.noise_overlap_threshold,
        }
        result = sf.sorters.mountainsort4(recording=rx,
                                          num_workers=workers,
                                          **sorter_options)
        si.MdaSortingExtractor.writeSorting(sorting=result,
                                            save_path=self.firings_out)
コード例 #3
0
class GenerateMearecRecording(mlpr.Processor):
    """Simulate a recording with MEArec from a templates file and save it.

    The MEArec default recording parameters are deep-copied and then
    overridden with this processor's parameters; ``seed`` is applied to the
    ``recordings``, ``spiketrains`` and ``templates`` sections alike.
    """
    NAME = "GenerateMearecRecording"
    VERSION = "0.1.0"

    # input file
    templates_in = mlpr.Input(description='.h5 file containing templates')

    # output file
    recording_out = mlpr.Output()

    # recordings params
    drifting = mlpr.BoolParameter()
    noise_level = mlpr.FloatParameter()
    bursting = mlpr.BoolParameter()
    shape_mod = mlpr.BoolParameter()

    # spiketrains params
    duration = mlpr.FloatParameter()
    n_exc = mlpr.IntegerParameter()
    n_inh = mlpr.IntegerParameter()

    # templates params
    min_dist = mlpr.FloatParameter()

    # seed
    seed = mlpr.IntegerParameter()

    def run(self):
        """Build MEArec params, load templates, generate and save the recording."""
        params = deepcopy(mr.get_default_recordings_params())

        # NOTE: setting params['recordings']['chunk_conv_duration'] = 0 would
        # turn off MEArec's parallel convolution; it is left at its default.
        params['recordings'].update(
            drifting=self.drifting,
            noise_level=self.noise_level,
            bursting=self.bursting,
            shape_mod=self.shape_mod,
            seed=self.seed,
        )
        params['spiketrains'].update(
            duration=self.duration,
            n_exc=self.n_exc,
            n_inh=self.n_inh,
            seed=self.seed,
        )
        params['templates'].update(
            min_dist=self.min_dist,
            seed=self.seed,
        )

        # mr.load_templates requires the file extension, so work on a copy that
        # has it appended (the copy is left next to the input file).
        h5_copy = self.templates_in + '.h5'
        shutil.copyfile(self.templates_in, h5_copy)
        templates = mr.load_templates(Path(h5_copy))

        generator = mr.gen_recordings(params=params,
                                      tempgen=templates,
                                      verbose=False)
        mr.save_recording_generator(generator, self.recording_out)
        del generator
コード例 #4
0
class SpykingCircus(mlpr.Processor):
    """Run Spyking Circus on an MDA recording directory and write an MDA firings file.

    The recording can optionally be restricted to ``channels``. Sorter output
    goes to a scratch directory under ``$TEMPDIR`` (default ``/tmp``), which is
    removed afterwards. The core count comes from ``NUM_WORKERS`` (default 2).
    """
    NAME = 'SpykingCircus'
    VERSION = '0.1.2'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(optional=True, default=1000, description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(optional=True, default=10000, description='I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort the recording, write ``firings_out``, and always remove the scratch dir.

        Raises:
            ValueError: if ``NUM_WORKERS`` is set but is not an integer.
            Anything raised by the extractor, the sorter or the writer; the
            scratch directory is removed before it propagates.
        """
        # Random suffix so concurrent runs do not share a scratch directory.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        # This prefix was previously 'ironclust-tmp-' (copy/paste from the
        # IronClust wrapper), which made scratch dirs misleading to debug.
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code

        # os.environ values are str; normalise so n_cores always gets an int
        # (previously it was a str when set and an int when defaulted).
        num_workers = int(os.environ.get('NUM_WORKERS', 2))

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for success and failure.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #5
0
class IronClust(mlpr.Processor):
    """Run IronClust on an MDA recording directory and write an MDA firings file.

    Requires the ``IRONCLUST_SRC`` environment variable, which is passed on as
    ``ironclust_src``. The sorter works in a scratch directory under
    ``$TEMPDIR`` (default ``/tmp``), which is removed afterwards.
    """
    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording, write ``firings_out``, and always remove the scratch dir.

        Raises:
            Exception: if ``IRONCLUST_SRC`` is unset or empty.
            Anything raised by the extractor, the sorter or the writer; the
            scratch directory is removed before it propagates.
        """
        ironclust_src = os.environ.get('IRONCLUST_SRC')
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')
        # Random suffix so concurrent runs do not share a scratch directory.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for success and failure (replaces the former
            # bare ``except:`` + duplicated rmtree).
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #6
0
class ComputeMandelbrotWithError(mlpr.Processor):
    """Compute a Mandelbrot iteration grid and save it as .npy, optionally failing on purpose.

    Identical to ``ComputeMandelbrot`` except for ``throw_error``. When it is
    set, the processor sleeps 2 seconds and then raises, which is useful for
    testing error handling in job runners.
    """
    NAME = 'ComputeMandelbrotWithError'
    VERSION = '0.1.4'

    # The bounds are real-valued coordinates (defaults 0.5, -1.25, 1.25), so
    # they are declared as FloatParameter. They were previously IntegerParameter,
    # which mis-declared their type.
    xmin = mlpr.FloatParameter('The minimum x value',
                               optional=True,
                               default=-2)
    xmax = mlpr.FloatParameter('The maximum x value',
                               optional=True,
                               default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value',
                               optional=True,
                               default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value',
                               optional=True,
                               default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension',
        optional=True,
        default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations',
                                     optional=True,
                                     default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)',
        optional=True,
        default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset',
                                               optional=True,
                                               default=0)
    throw_error = mlpr.BoolParameter(
        'Whether to intentionally throw an error for testing purposes',
        optional=True,
        default=False)

    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Compute the grid via ``compute_mandelbrot`` and save it to ``output_npy``.

        Raises:
            Exception: after a 2-second delay when ``throw_error`` is true.
        """
        import time
        if self.throw_error:
            print('Intentionally throwing error in 2 seconds...')
            time.sleep(2)
            raise Exception('Intentionally throwing error.')
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(xmin=self.xmin,
                               xmax=self.xmax,
                               ymin=self.ymin,
                               ymax=self.ymax,
                               num_x=self.num_x,
                               num_iter=self.num_iter,
                               subsampling_factor=self.subsampling_factor,
                               subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
コード例 #7
0
ファイル: computeunitdetail.py プロジェクト: yger/spikeforest
class ComputeUnitDetail(mlpr.Processor):
    """Compute the median waveform of one sorted unit and save it as JSON.

    Output JSON keys: ``channel_ids`` (the recording's channel ids) and
    ``average_waveform`` (median over the random spike waveforms, taken along
    axis 2, as a nested list).
    """
    NAME = 'ComputeUnitDetail'
    VERSION = '0.1.0'
    CONTAINER = None

    recording_dir = mlpr.Input(description='Recording directory',
                               optional=False,
                               directory=True)
    firings = mlpr.Input(description='Input firings.mda file')
    unit_id = mlpr.IntegerParameter(description='Unit ID')
    json_out = mlpr.Output(description='Output .json file')

    def run(self):
        """Load recording/sorting, take the median waveform of ``unit_id``, write JSON."""
        # The input is declared as ``recording_dir``; the previous code read the
        # non-existent ``self.recording_directory`` and raised AttributeError.
        recording = SFMdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        waveforms0 = _get_random_spike_waveforms(recording=recording,
                                                 sorting=sorting,
                                                 unit=self.unit_id)
        channel_ids = recording.get_channel_ids()
        # Median across axis 2 — presumably the spike axis of the waveform
        # array returned by the helper (TODO confirm).
        avg_waveform = np.median(waveforms0, axis=2)
        ret = dict(channel_ids=channel_ids,
                   average_waveform=avg_waveform.tolist())
        with open(self.json_out, 'w') as f:
            json.dump(ret, f)
コード例 #8
0
class CreateEfficientAccessRecordingFile(mlpr.Processor):
    """Write a recording to HDF5 as fixed-size segments, one dataset per (channel, segment).

    Datasets written: ``segment_size``, ``num_segments``, ``num_channels``,
    ``channel_ids``, ``num_timepoints``, ``samplerate``, ``geom``,
    ``recording_hash``, and ``part-<channel_id>-<segment_index>``. Each
    ``part-`` dataset holds ``segment_size`` samples; the last segment is
    zero-padded past the end of the recording.
    """
    NAME = 'CreateEfficientAccessRecordingFile'
    VERSION = '0.1.2'
    recording = mlpr.Input()
    segment_size = mlpr.IntegerParameter(optional=True, default=1000000)
    hdf5_out = mlpr.Output()

    def run(self):
        """Read the recording segment by segment and write ``hdf5_out``."""
        import h5py
        recording = self.recording
        segment_size = self.segment_size
        channel_ids = recording.get_channel_ids()
        samplerate = recording.get_sampling_frequency()
        M = len(channel_ids)  # number of channels
        N = recording.get_num_frames()  # number of timepoints
        num_segments = int(np.ceil(N / segment_size))

        try:
            channel_locations = recording.get_channel_locations(
                channel_ids=channel_ids)
            nd = len(channel_locations[0])
            geom = np.zeros((M, nd))
            for m in range(M):
                geom[m, :] = channel_locations[m]
        except Exception:
            # Best effort: without usable channel locations fall back to an
            # all-zero 2-D geometry. Previously a bare ``except:``, which also
            # swallowed KeyboardInterrupt/SystemExit.
            geom = np.zeros((M, 2))

        with h5py.File(self.hdf5_out, "w") as f:
            f.create_dataset('segment_size', data=[segment_size])
            f.create_dataset('num_segments', data=[num_segments])
            f.create_dataset('num_channels', data=[M])
            f.create_dataset('channel_ids', data=np.array(channel_ids))
            f.create_dataset('num_timepoints', data=[N])
            f.create_dataset('samplerate', data=[samplerate])
            f.create_dataset('geom', data=geom)
            # ``hash`` may be exposed as a method or as a plain attribute.
            if callable(recording.hash):
                hash0 = recording.hash()
            else:
                hash0 = recording.hash
            f.create_dataset('recording_hash', data=np.array([hash0.encode()]))
            for j in range(num_segments):
                t1 = int(j * segment_size)  # first timepoint of the segment
                t2 = int(min(N, t1 + segment_size))  # last timepoint of segment (+1)
                # Fixed-length buffer so every segment dataset has the same
                # length; samples past the end of the recording stay zero.
                segment = np.zeros((M, segment_size),
                                   dtype=float)  # fix dtype here
                segment[:, :t2 - t1] = recording.get_traces(
                    start_frame=t1, end_frame=t2)  # Read the segment

                for ii, ch in enumerate(channel_ids):
                    f.create_dataset('part-{}-{}'.format(ch, j),
                                     data=segment[ii, :].ravel())
コード例 #9
0
class YASS(mlpr.Processor):
    """Run the YASS sorter on an MDA recording directory and write an MDA firings file.

    Work happens in a scratch directory under ``$TEMPDIR`` (default ``/tmp``).
    The directory is always removed on failure. On success it is removed
    unless the instance has a truthy ``_keep_temp_files`` attribute.
    """
    NAME = 'YASS'
    VERSION = '0.1.0'
    # used by container to pass the env variables
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        """Sort via ``yass_helper``, write ``firings_out``, then clean up the scratch dir."""
        suffix = ''.join(random.choice(string.ascii_uppercase)
                         for _ in range(10))
        scratch = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + suffix

        try:
            rx = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                rx = se.SubRecordingExtractor(
                    parent_recording=rx, channel_ids=self.channels)
            if not os.path.exists(scratch):
                os.mkdir(scratch)
            # yass_helper also returns the path of the YAML config it used;
            # it is not needed here.
            result, _config_path = yass_helper(
                recording=rx,
                output_folder=scratch,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            se.MdaSortingExtractor.writeSorting(
                sorting=result, save_path=self.firings_out)
        except BaseException:
            # On any failure, remove the scratch dir regardless of _keep_temp_files.
            if os.path.exists(scratch):
                shutil.rmtree(scratch)
            raise

        keep = getattr(self, '_keep_temp_files', False)
        if not keep:
            shutil.rmtree(scratch)
コード例 #10
0
class CreateSpikeSprays(mlpr.Processor):
    """Build spike-spray data for one ground-truth/sorted unit pair and save it as JSON.

    The recording is read from ``recording_directory`` using
    ``filtered_timeseries`` as the raw file. The result of
    ``create_spikesprays`` is dumped to ``json_out`` unchanged.
    """
    NAME = 'CreateSpikeSprays'
    VERSION = '0.1.0'
    CONTAINER = _CONTAINER

    recording_directory = mlpr.Input(description='Recording directory',
                                     optional=False,
                                     directory=True)
    filtered_timeseries = mlpr.Input(
        description='Filtered timeseries file (.mda)', optional=False)
    firings_true = mlpr.Input(description='True firings -- firings_true.mda',
                              optional=False)
    firings_sorted = mlpr.Input(description='Sorted firings -- firings.mda',
                                optional=False)
    unit_id_true = mlpr.IntegerParameter(description='ID of the true unit')
    unit_id_sorted = mlpr.IntegerParameter(description='ID of the sorted unit')
    neighborhood_size = mlpr.IntegerParameter(
        description='Max size of the electrode neighborhood',
        optional=True,
        default=7)
    num_spikes = mlpr.IntegerParameter(
        description='Max number of spikes in the spike spray',
        optional=True,
        default=20)
    json_out = mlpr.Output(description='Output json object')

    def run(self):
        """Load the recording and both sortings, compute the spray, write JSON."""
        recording = SFMdaRecordingExtractor(
            dataset_directory=self.recording_directory,
            download=True,
            raw_fname=self.filtered_timeseries)
        truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        sorted_result = SFMdaSortingExtractor(firings_file=self.firings_sorted)

        spray_options = dict(neighborhood_size=self.neighborhood_size,
                             num_spikes=self.num_spikes,
                             unit_id_true=self.unit_id_true,
                             unit_id_sorted=self.unit_id_sorted)
        spray = create_spikesprays(rx=recording,
                                   sx_true=truth,
                                   sx_sorted=sorted_result,
                                   **spray_options)

        with open(self.json_out, 'w') as out_file:
            json.dump(spray, out_file)
コード例 #11
0
class ComputeMandelbrot(mlpr.Processor):
    """Compute a Mandelbrot iteration grid over a rectangle and save it as .npy.

    With ``subsampling_factor > 1``, the offset/factor pair is passed to
    ``compute_mandelbrot``. Presumably each job then computes an interleaved
    slice that ``CombineSubsampledMandelbrot`` reassembles (TODO confirm
    against ``compute_mandelbrot``).
    """
    NAME = 'ComputeMandelbrot'
    VERSION = '0.1.4'

    # The bounds are real-valued coordinates (defaults 0.5, -1.25, 1.25), so
    # they are declared as FloatParameter. They were previously IntegerParameter,
    # which mis-declared their type.
    xmin = mlpr.FloatParameter('The minimum x value',
                               optional=True,
                               default=-2)
    xmax = mlpr.FloatParameter('The maximum x value',
                               optional=True,
                               default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value',
                               optional=True,
                               default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value',
                               optional=True,
                               default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension',
        optional=True,
        default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations',
                                     optional=True,
                                     default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)',
        optional=True,
        default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset',
                                               optional=True,
                                               default=0)

    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Compute the grid via ``compute_mandelbrot`` and save it to ``output_npy``."""
        print('=== ComputeMandelbrot ===', self.subsampling_factor,
              self.subsampling_offset)
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(xmin=self.xmin,
                               xmax=self.xmax,
                               ymin=self.ymin,
                               ymax=self.ymax,
                               num_x=self.num_x,
                               num_iter=self.num_iter,
                               subsampling_factor=self.subsampling_factor,
                               subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
コード例 #12
0
class ComputeNPrimes(mlpr.Processor):
    """Compute the first ``n`` primes, print the last one, and save them as .npy."""
    NAME = 'ComputeNPrimes'
    VERSION = '0.1.3'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output .npy file.')

    def __init__(self):
        super().__init__()

    def run(self):
        """Run ``compute_n_primes(n)``, report the largest, and save the array."""
        found = compute_n_primes(self.n)
        largest = found[-1]
        print('Prime {}: {}'.format(self.n, largest))
        np.save(self.output, found)
コード例 #13
0
class ComputeNthPrime(mlpr.Processor):
    """Compute the n-th prime and write it to a text file."""
    NAME = 'ComputeNthPrime'
    VERSION = '0.1.1'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output text file.')

    def __init__(self):
        super().__init__()

    def run(self):
        """Write ``nth_prime_number(n)`` as text to ``output``."""
        text = format(nth_prime_number(self.n))
        with open(self.output, 'w') as out_file:
            out_file.write(text)
コード例 #14
0
class CombineSubsampledMandelbrot(mlpr.Processor):
    """Merge subsampled Mandelbrot arrays into one array with ``num_x`` rows and save it.

    Each input is a .npy file. The arrays are combined by
    ``combine_subsampled_mandelbrot`` and the result is cropped to its first
    ``num_x`` rows.
    """
    NAME = 'CombineSubsampledMandelbrot'
    VERSION = '0.1.1'

    X_list = mlpr.Input(multi=True)
    X_out = mlpr.Output()
    num_x = mlpr.IntegerParameter()

    def run(self):
        """Load every input array, combine, crop to ``num_x`` rows, and save to ``X_out``."""
        print('=== CombineSubsampledMandelbrot ===', self.num_x)
        # (A no-op ``self.X_list`` expression statement was removed here.)
        arrays = [np.load(path) for path in self.X_list]
        X = combine_subsampled_mandelbrot(arrays)
        # Presumably the combined array can have more than num_x rows when
        # num_x is not a multiple of the subsampling factor (TODO confirm).
        X = X[:self.num_x, :]
        np.save(self.X_out, X)
コード例 #15
0
class ComputeAutocorrelograms(mlpr.Processor):
    """Compute an autocorrelogram for every unit of a firings file and save it as JSON.

    ``bin_size_msec`` and ``max_dt_msec`` are converted to timepoints using
    ``samplerate``. The ``bin_edges`` written out are whatever
    ``compute_autocorrelogram`` returns; no conversion back to milliseconds is
    done here.
    """
    NAME = 'ComputeAutocorrelograms'
    VERSION = '0.1.4'

    # Inputs
    firings_path = mlpr.Input()

    # Parameters
    samplerate = mlpr.FloatParameter()
    max_samples = mlpr.IntegerParameter(optional=True, default=100000)
    bin_size_msec = mlpr.FloatParameter(optional=True, default=2)
    max_dt_msec = mlpr.FloatParameter(optional=True, default=50)

    # Outputs
    json_out = mlpr.Output()

    def run(self):
        """Load the sorting, compute per-unit autocorrelograms, and write ``json_out``."""
        # Only the sorting extractor is needed; the unused recording-extractor
        # import and leftover debug prints were removed.
        from spikeforest import SFMdaSortingExtractor

        sorting = SFMdaSortingExtractor(firings_file=self.firings_path)
        samplerate = self.samplerate
        max_samples = self.max_samples

        # Convert millisecond settings to timepoints (samples).
        max_dt_tp = self.max_dt_msec * samplerate / 1000
        bin_size_tp = self.bin_size_msec * samplerate / 1000

        autocorrelograms = []
        for unit_id in sorting.get_unit_ids():
            bin_counts, bin_edges = compute_autocorrelogram(
                sorting.get_unit_spike_train(unit_id),
                max_dt_tp=max_dt_tp,
                bin_size_tp=bin_size_tp,
                max_samples=max_samples)
            autocorrelograms.append(dict(
                unit_id=unit_id,
                bin_counts=bin_counts,
                bin_edges=bin_edges
            ))
        ret = dict(
            autocorrelograms=autocorrelograms
        )
        with open(self.json_out, 'w') as f:
            # serialize_np presumably converts numpy values to JSON-able types.
            json.dump(serialize_np(ret), f)
コード例 #16
0
ファイル: unitdetailview.py プロジェクト: yger/spikeforest
class ComputeUnitDetail(mlpr.Processor):
    """Collect spike times, 100-sample snippets and a median template for one unit; pickle them.

    The pickled dict has keys ``unit_id``, ``num_events``, ``event_times``,
    ``snippets`` and ``template`` (the median of the stacked snippets along
    axis 0).
    """
    NAME = 'ComputeUnitsDetail'
    VERSION = '0.1.5'
    recording = mlpr.Input()
    sorting = mlpr.Input()
    unit_id = mlpr.IntegerParameter()
    output = mlpr.Output()

    def run(self):
        """Extract the unit's snippets, compute the median template, and pickle the summary."""
        sorting = self.sorting
        recording = self.recording
        spike_frames = sorting.get_unit_spike_train(unit_id=self.unit_id)  # pylint: disable=no-member
        unit_snippets = recording.get_snippets(reference_frames=spike_frames, snippet_len=100)  # pylint: disable=no-member
        median_template = np.median(np.stack(unit_snippets), axis=0)
        summary = {
            'unit_id': self.unit_id,
            'num_events': len(spike_frames),
            'event_times': spike_frames,
            'snippets': unit_snippets,
            'template': median_template,
        }
        with open(self.output, 'wb') as out_file:
            pickle.dump(summary, out_file)
コード例 #17
0
class ComputeAutocorrelograms(mlpr.Processor):
    """Compute per-unit autocorrelograms for a sorting and save them as JSON.

    The sampling rate comes from the sorting itself. Bin edges are converted
    from timepoints to milliseconds before writing, and NaNs are written via
    ``simplejson``'s ``ignore_nan`` option.
    """
    NAME = 'ComputeAutocorrelograms'
    VERSION = '0.1.5'

    # Inputs
    sorting = mlpr.Input()

    # Parameters
    max_samples = mlpr.IntegerParameter(optional=True, default=100000)
    bin_size_msec = mlpr.FloatParameter(optional=True, default=2)
    max_dt_msec = mlpr.FloatParameter(optional=True, default=50)

    # Outputs
    json_out = mlpr.Output()

    def run(self):
        """Compute an autocorrelogram per unit and write them to ``json_out``."""
        sx = self.sorting
        fs = sx.get_sampling_frequency()
        sample_cap = self.max_samples

        # Millisecond settings expressed in timepoints (samples).
        window_tp = self.max_dt_msec * fs / 1000
        bin_tp = self.bin_size_msec * fs / 1000

        def _one_unit(uid):
            # Autocorrelogram for a single unit, with edges converted to ms.
            counts, edges_tp = compute_autocorrelogram(
                sx.get_unit_spike_train(unit_id=uid),
                max_dt_tp=window_tp,
                bin_size_tp=bin_tp,
                max_samples=sample_cap)
            return dict(unit_id=uid,
                        bin_counts=counts,
                        bin_edges=edges_tp / fs * 1000)

        payload = dict(
            autocorrelograms=[_one_unit(uid) for uid in sx.get_unit_ids()])
        with open(self.json_out, 'w') as out_file:
            simplejson.dump(serialize_np(payload), out_file, ignore_nan=True)
コード例 #18
0
ファイル: tridesclous.py プロジェクト: yger/spikeforest
class TridesclousOld(mlpr.Processor):
    """
    tridesclous is one of the more convenient, fast and elegant
    spike sorters.

    Installation instruction::

        pip install https://github.com/tridesclous/tridesclous/archive/master.zip

    More information on tridesclous at:
    * https://github.com/tridesclous/tridesclous
    * https://tridesclous.readthedocs.io

    """

    NAME = 'Tridesclous'
    VERSION = '0.1.1'  # wrapper VERSION
    ADDITIONAL_FILES = []
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    CONTAINER = 'sha1://9fb4a9350492ee84c8ea5d8692434ecba3cf33da/2019-05-13/tridesclous.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[], description='List of channels to use.')
    detect_sign = mlpr.FloatParameter(optional=True,
                                      default=-1,
                                      description='')
    detection_threshold = mlpr.FloatParameter(optional=True,
                                              default=5.5,
                                              description='')
    freq_min = mlpr.FloatParameter(optional=True, default=400, description='')
    freq_max = mlpr.FloatParameter(optional=True, default=5000, description='')
    waveforms_n_left = mlpr.IntegerParameter(description='',
                                             optional=True,
                                             default=-45)
    waveforms_n_right = mlpr.IntegerParameter(description='',
                                              optional=True,
                                              default=60)
    align_waveform = mlpr.BoolParameter(description='',
                                        optional=True,
                                        default=False)
    common_ref_removal = mlpr.BoolParameter(description='',
                                            optional=True,
                                            default=False)
    peak_span = mlpr.FloatParameter(optional=True,
                                    default=.0002,
                                    description='')
    alien_value_threshold = mlpr.FloatParameter(optional=True,
                                                default=100,
                                                description='')

    def run(self):
        """Sort ``recording_dir`` with tridesclous and write ``firings_out``.

        The probe file, an optional float32 copy of the signals and the
        tridesclous working data all live in a temporary directory. That
        directory is removed afterwards, including when any step fails,
        unless the processor has ``_keep_temp_files`` set.
        """
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        # Everything that writes into tmpdir lives inside the try, so a
        # failure while preparing the inputs also cleans up.
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Honour the declared `channels` parameter (empty = all channels).
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)

            # Map detect_sign onto tridesclous' peak_sign. Positive means '+'.
            # Anything else keeps the previously hard-coded '-'.
            peak_sign = '+' if self.detect_sign > 0 else '-'

            params = {
                'fullchain_kargs': {
                    'duration': 300.,
                    'preprocessor': {
                        'highpass_freq': self.freq_min,
                        'lowpass_freq': self.freq_max,
                        'smooth_size': 0,
                        'chunksize': 1024,
                        'lostfront_chunksize': 128,
                        'signalpreprocessor_engine': 'numpy',
                        'common_ref_removal': self.common_ref_removal,
                    },
                    'peak_detector': {
                        'peakdetector_engine': 'numpy',
                        'peak_sign': peak_sign,
                        'relative_threshold': self.detection_threshold,
                        'peak_span': self.peak_span,
                    },
                    'noise_snippet': {
                        'nb_snippet': 300,
                    },
                    'extract_waveforms': {
                        'n_left': self.waveforms_n_left,
                        'n_right': self.waveforms_n_right,
                        'mode': 'rand',
                        'nb_max': 20000,
                        'align_waveform': self.align_waveform,
                    },
                    'clean_waveforms': {
                        'alien_value_threshold': self.alien_value_threshold,
                    },
                },
                'feat_method': 'peak_max',
                'feat_kargs': {},
                'clust_method': 'sawchaincut',
                'clust_kargs': {
                    'kde_bandwith': 1.
                },
            }

            # save prb file:
            probe_file = tmpdir / 'probe.prb'
            se.save_probe_file(recording, probe_file, format='spyking_circus')

            nb_chan = recording.get_num_channels()

            # source file
            if isinstance(recording, se.BinDatRecordingExtractor
                          ) and recording._frame_first:
                # no need to copy: read the existing binary file in place
                raw_filename = recording._datfile
                dtype = recording._timeseries.dtype.str
                offset = recording._timeseries.offset
            else:
                # save binary file (chunk by chunk) into a new float32 file
                raw_filename = tmpdir / 'raw_signals.raw'
                se.write_binary_dat_format(recording,
                                           raw_filename,
                                           time_axis=0,
                                           dtype='float32',
                                           chunksize=2**24 // nb_chan)
                dtype = 'float32'
                offset = 0

            # initialize source and probe file
            tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
            tdc_dataio.set_data_source(
                type='RawData',
                filenames=[str(raw_filename)],
                dtype=dtype,
                sample_rate=recording.get_sampling_frequency(),
                total_channel=nb_chan,
                offset=offset)
            tdc_dataio.set_probe_file(str(probe_file))

            sorting = tdc_helper(tmpdir=tmpdir,
                                 params=params,
                                 recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #19
0
class JRClust(mlpr.Processor):
    """
    JRClust v4 wrapper
      written by J. James Jun, May 8, 2019
      modified from `spikeforest/spikesorters/ironclust/ironclust.py`

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/JaneliaSciComp/JRCLUST`
    2. Activate conda environment for SpikeForest
    3. Create `JRCLUST_PATH` and `MDAIO_PATH`
    4. Flatiron execution: `module load matlab/R2019a cuda/10.0.130_410.48`
    DO NOT LOAD `module load gcc`.
    DO NOT USE `matlab/R2018b` (it will crash while creating a parpool).

    See: James Jun, et al. Real-time spike sorting platform for
    high-density extracellular probes with ground-truth validation and drift correction
    https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True,
        default=-1,
        description=
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(optional=True,
                                           default=50,
                                           description='')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=4.5,
                                           description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=3000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True,
        default=0.98,
        description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True,
        default=1,
        description='Number of principal components per channel')

    # added in version 0.2.4
    filter_type = mlpr.StringParameter(
        optional=True,
        default='bandpass',
        description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(optional=True,
                                           default='none',
                                           description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(optional=True,
                                      default=30,
                                      description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(optional=True,
                                 default=1,
                                 description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(optional=True,
                                    default=0,
                                    description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True,
        default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        """Run JRClust on ``recording_dir`` and write ``firings_out``.

        Every declared processor parameter is forwarded to
        ``jrclust_helper`` as a keyword argument. The scratch directory is
        deleted on both success and failure unless ``_keep_temp_files`` is set.
        """
        tmpdir = _get_tmpdir('jrclust')
        keep_tmp = getattr(self, '_keep_temp_files', False)

        try:
            rec = SFMdaRecordingExtractor(self.recording_dir)
            dataset_params = read_dataset_params(self.recording_dir)
            if len(self.channels):
                rec = se.SubRecordingExtractor(parent_recording=rec,
                                               channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Snapshot of all processor parameters, keyed by parameter name.
            forwarded = {p.name: getattr(self, p.name) for p in self.PARAMETERS}

            result = jrclust_helper(recording=rec,
                                    tmpdir=tmpdir,
                                    params=dataset_params,
                                    **forwarded)
            SFMdaSortingExtractor.write_sorting(sorting=result,
                                                save_path=self.firings_out)
        except BaseException:
            if not keep_tmp and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise

        if not keep_tmp:
            shutil.rmtree(tmpdir)
コード例 #20
0
class SpykingCircus(mlpr.Processor):
    """SpyKING CIRCUS wrapper built on ``spikesorters.SpykingcircusSorter``."""
    NAME = 'SpykingCircus'
    VERSION = '0.3.4'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.params']
    # Previous containers:
    # CONTAINER = 'sha1://8958530b960522d529163344af2faa09ea805716/2019-05-06/spyking_circus.simg'
    # CONTAINER = 'sha1://eed2314fbe2fb1cc7cfe0a36b4e205ffb94add1c/2019-06-17/spyking_circus.simg'
    # CONTAINER = 'sha1://68a175faef53e29af068b8b95649021593f9020a/2019-07-01/spyking_circus.simg'
    CONTAINER = 'sha1://5ca21c482edaf4b3b689f2af3c719a32567ba21e/2019-07-22/spyking_circus.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True,
        default=200,
        description=
        'Channel neighborhood adjacency radius corresponding to geom file')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=1000,
        description=
        'I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=10000,
        description=
        'I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort ``recording_dir`` with SpyKING CIRCUS and write ``firings_out``.

        Prints the sorter runtime in the ``#SF-SORTER-RUNTIME#...#`` format.
        """
        import spikesorters as sorters
        print('SpyKING CIRCUS......')
        rec = SFMdaRecordingExtractor(self.recording_dir)

        # Random 10-letter suffix so concurrent runs get distinct folders.
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        out_dir = '{}/spyking-circus-{}'.format(
            os.environ.get('TEMPDIR', '/tmp'), suffix)

        n_workers = int(os.environ.get('NUM_WORKERS', '1'))

        sorter = sorters.SpykingcircusSorter(recording=rec,
                                             output_folder=out_dir,
                                             verbose=True,
                                             delete_output_folder=True)

        sorter_params = dict(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            auto_merge=0.5,
            num_workers=n_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts,
        )
        sorter.set_params(**sorter_params)

        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))

        SFMdaSortingExtractor.write_sorting(sorting=sorter.get_result(),
                                            save_path=self.firings_out)
コード例 #21
0
ファイル: mountainsort4.py プロジェクト: yger/spikeforest
class MountainSort4Old(mlpr.Processor):
    """MountainSort4 wrapper that calls ``ml_ms4alg`` directly.

    Optional bandpass filtering and whitening are applied before sorting.
    """
    NAME = 'MountainSort4'
    VERSION = '4.3.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = 'sha1://e06fee7f72f6b66d80d899ebc08e7c39e5a2458e/2019-05-06/mountainsort4.simg'
    # CONTAINER = 'sha1://8743ff094a26bdedd16f36209a05333f1f82fbd8/2019-06-26/mountainsort4.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(optional=True, default=True,
                                description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10, description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        """Preprocess, sort with ``ml_ms4alg.mountainsort4``, write ``firings_out``."""
        from .bandpass_filter import bandpass_filter
        from .whiten import whiten

        import ml_ms4alg

        print('MountainSort4......')
        rec = SFMdaRecordingExtractor(self.recording_dir)

        # NUM_WORKERS is converted to int only when set and non-empty.
        # Otherwise the raw value (None or '') is passed through.
        env_workers = os.environ.get('NUM_WORKERS', None)
        n_workers = int(env_workers) if env_workers else env_workers

        # Bandpass filter unless both cutoffs are 0/None.
        if self.freq_min or self.freq_max:
            rec = bandpass_filter(recording=rec,
                                  freq_min=self.freq_min,
                                  freq_max=self.freq_max)

        # Channel whitening (the `whiten` name here is the imported helper;
        # the processor parameter is `self.whiten`).
        if self.whiten:
            rec = whiten(recording=rec)

        sort_kwargs = dict(
            recording=rec,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=n_workers,
        )
        result = ml_ms4alg.mountainsort4(**sort_kwargs)

        # Automated curation is currently disabled, so noise_overlap_threshold
        # has no effect. Previous code, kept for reference:
        # if self.noise_overlap_threshold is not None:
        #    sorting=ml_ms4alg.mountainsort4_curation(
        #      recording=recording,
        #      sorting=sorting,
        #      noise_overlap_threshold=self.noise_overlap_threshold
        #    )

        SFMdaSortingExtractor.write_sorting(
            sorting=result, save_path=self.firings_out)
コード例 #22
0
ファイル: kilosort2.py プロジェクト: samuelgarcia/spikeforest
class KiloSort2(mlpr.Processor):
    """
    Kilosort2 wrapper built on ``spikesorters.Kilosort2Sorter``.

    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0
    """

    NAME = 'KiloSort2'
    VERSION = '0.4.4'  # wrapper VERSION
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        default=-1,
        optional=True,
        description=
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30, optional=True, description='The sigmaMask for kilosort2')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=6,
                                           description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=150,
        description='Use 0 for no bandpass filtering')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50,
        optional=True,
        description=
        'minimum spike rate (Hz), if a cluster falls below this for too long it gets removed'
    )
    car = mlpr.BoolParameter(
        default=True,
        optional=True,
        description='whether to do common average referencing')

    @staticmethod
    def install():
        """Install Kilosort2 at a pinned commit and return its path."""
        print('Auto-installing kilosort2.')
        return install_kilosort2(
            repo='https://github.com/MouseLand/Kilosort2',
            commit='5629125f072795b082245f4265b567d3540cbc2b')

    def run(self):
        """Sort ``recording_dir`` with Kilosort2 and write ``firings_out``.

        Raises:
            Exception: if installing Kilosort2 fails. The original error is
                attached as ``__cause__``.
        """
        import spikesorters as sorters
        print('Kilosort2......')

        # Catch only ordinary errors: a bare `except:` would also convert
        # KeyboardInterrupt/SystemExit into a generic install failure.
        try:
            kilosort2_path = KiloSort2.install()
        except Exception as err:
            traceback.print_exc()
            raise Exception('Problem installing kilosort.') from err
        sorters.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-' + code

        sorter = sorters.Kilosort2Sorter(recording=recording,
                                         output_folder=tmpdir,
                                         debug=True,
                                         delete_output_folder=True)

        # NOTE(review): detect_sign is declared above but not forwarded here.
        # Confirm whether Kilosort2Sorter accepts an equivalent parameter.
        sorter.set_params(detect_threshold=self.detect_threshold,
                          car=self.car,
                          minFR=self.minFR,
                          electrode_dimensions=None,
                          freq_min=self.freq_min,
                          sigmaMask=self.adjacency_radius,
                          nPCs=self.pc_per_chan)

        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
コード例 #23
0
class SpykingCircus(mlpr.Processor):
    """SpyKING CIRCUS wrapper using the legacy ``spyking_circus`` helper."""
    NAME = 'SpykingCircus'
    VERSION = '0.2.2'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.params']
    CONTAINER = 'sha1://914becce45aec56a84dd1dd4bca4037b09c50373/02-12-2019/spyking_circus.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True,
        default=100,
        description=
        'Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True,
                                       default=6,
                                       description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=1000,
        description=
        'I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=10000,
        description=
        'I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort ``recording_dir`` with SpyKING CIRCUS and write ``firings_out``.

        The scratch directory is removed afterwards, including on failure,
        unless the processor has ``_keep_temp_files`` set.
        """
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR',
                                '/tmp') + '/spyking-circus-tmp-' + code

        # Environment values are always strings. Convert so n_cores is an
        # int both when NUM_WORKERS is set and when the default is used.
        num_workers = int(os.environ.get('NUM_WORKERS', '1'))
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #24
0
ファイル: kilosort2.py プロジェクト: samuelgarcia/spikeforest
class KiloSort2Old(mlpr.Processor):
    """
    KiloSort2 wrapper for SpikeForest framework
      written by J. James Jun, May 7, 2019
      modified from `spiketoolkit/sorters/Kilosort`
      to be made compatible with SpikeForest

    [Prerequisite]
    1. MATLAB (Tested on R2018b)
    2. CUDA Toolkit v9.1

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/alexmorley/Kilosort2.git`
      Kilosort2 currently doesn't work on tetrodes and low-channel count probes (as of May 7, 2019).
      Clone from Alex Morley's repository that fixed these issues.
      Original Kilosort2 code can be obtained from `https://github.com/MouseLand/Kilosort2.git`
    2. (optional) If Alex Morley's latest version doesn't work with SpikeForest, run
        `git checkout 43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956`
    3. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    4. Add `KILOSORT2_PATH=...` in your .bashrc file.
    """

    NAME = 'KiloSort2'
    VERSION = '0.3.3'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        default=-1,
        optional=True,
        description=
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30,
        optional=True,
        description='Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=6,
                                           description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=150,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True,
                                       default=0.98,
                                       description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50,
        optional=True,
        description=
        'minimum spike rate (Hz), if a cluster falls below this for too long it gets removed'
    )

    @staticmethod
    def install():
        """Install Kilosort2 from Alex Morley's fork at a pinned commit."""
        print('Auto-installing kilosort.')
        return install_kilosort2(
            repo='https://github.com/alexmorley/Kilosort2',
            commit='43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956')

    def run(self):
        """Sort ``recording_dir`` with ``kilosort2_helper``; write ``firings_out``.

        The scratch directory is removed on success and on failure unless
        the processor has ``_keep_temp_files`` set.
        """
        # Honour the per-processor flag, as the other wrappers do. A
        # hard-coded True made both cleanup branches below unreachable and
        # left one temp directory behind on every run.
        _keep_temp_files = getattr(self, '_keep_temp_files', False)

        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort2_helper(recording=recording,
                                       tmpdir=tmpdir,
                                       detect_sign=self.detect_sign,
                                       adjacency_radius=self.adjacency_radius,
                                       detect_threshold=self.detect_threshold,
                                       merge_thresh=self.merge_thresh,
                                       freq_min=self.freq_min,
                                       freq_max=self.freq_max,
                                       pc_per_chan=self.pc_per_chan,
                                       minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not _keep_temp_files:
                    print('removing tmpdir1')
                    shutil.rmtree(tmpdir)
            raise
        if not _keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
コード例 #25
0
ファイル: mountainsort4.py プロジェクト: yger/spikeforest
class MountainSort4(mlpr.Processor):
    """MountainSort4 wrapper built on ``spikesorters.Mountainsort4Sorter``."""
    NAME = 'MountainSort4'
    VERSION = '4.3.1'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    # CONTAINER = 'sha1://e06fee7f72f6b66d80d899ebc08e7c39e5a2458e/2019-05-06/mountainsort4.simg'
    CONTAINER = 'sha1://8743ff094a26bdedd16f36209a05333f1f82fbd8/2019-06-26/mountainsort4.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(optional=True, default=True,
                                description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10, description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        """Sort ``recording_dir`` with MountainSort4 and write ``firings_out``.

        Automated curation is disabled (``curation=False``), so
        ``noise_overlap_threshold`` currently has no effect.
        """
        # from spikeinterface/spikesorters
        import spikesorters as sorters

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/mountainsort4-' + code

        sorter = sorters.Mountainsort4Sorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
            curation=False,
            # These were previously hard-coded to True, which silently
            # ignored the `whiten` parameter and "Use 0 for no bandpass
            # filtering". Filtering is now skipped when both cutoffs are 0.
            whiten=self.whiten,
            filter=bool(self.freq_min or self.freq_max),
            freq_min=self.freq_min,
            freq_max=self.freq_max
        )

        # TODO: get elapsed time from the return of this run
        sorter.run()

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
コード例 #26
0
class CustomSorting(mlpr.Processor):
    """Preprocess a raw .mda recording, sort it with MountainSort4, and curate it.

    Pipeline (intermediates live in a temporary directory):
    bandpass filter -> optional artifact masking -> optional whitening ->
    MountainSort4 -> cluster + isolation metrics -> label map -> curated firings.
    """
    NAME = 'CustomSorting'
    VERSION = '0.1.8'  # the version can be incremented when the code inside run() changes

    # input files
    recording_file_in = mlpr.Input('Path to raw.mda')
    geom_in = mlpr.Input('Path to geom.csv', optional=True)

    # output files
    firings_out = mlpr.Output('Output firings.mda file')
    firings_curated_out = mlpr.Output('Output firings.curated.mda file')
    metrics_out = mlpr.Output('Metrics .json output')

    # parameters
    samplerate = mlpr.FloatParameter("Sampling frequency")

    mask_out_artifacts = mlpr.BoolParameter(
        optional=True, default=False,
        description='Whether to mask out artifacts')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10, description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        """Run the full pipeline and write firings, curated firings and metrics.

        Writes ``self.firings_out`` (raw MountainSort4 output),
        ``self.metrics_out`` (combined cluster/isolation metrics) and
        ``self.firings_curated_out`` (firings after applying the label map).

        Raises:
            Exception: if no geom file is provided and the recording has
                more than six channels.
        """
        # The temporary directory (and everything in it) is removed
        # automatically, even in the case of a Python exception.
        with TemporaryDirectory() as tmpdir:
            # names of files for the temporary/intermediate data
            filt = tmpdir + '/filt.mda'
            filt2 = tmpdir + '/filt2.mda'
            pre = tmpdir + '/pre.mda'

            print('Bandpass filtering raw -> filt...')
            # NOTE(review): self.freq_min / self.freq_max are not forwarded
            # here, so the cutoffs are whatever _bandpass_filter uses -- confirm.
            _bandpass_filter(self.recording_file_in, filt)

            if self.mask_out_artifacts:
                print('Masking out artifacts filt -> filt2...')
                _mask_out_artifacts(filt, filt2)
            else:
                print('Copying filt -> filt2...')
                # No physical copy: downstream steps just read the filtered file.
                filt2 = filt

            if self.whiten:
                print('Whitening filt2 -> pre...')
                _whiten(filt2, pre)
            else:
                pre = filt2

            # read the preprocessed timeseries into RAM (maybe we'll do it differently later)
            X = sf.mdaio.readmda(pre)
            num_channels = X.shape[0]

            # handle the geom
            if isinstance(self.geom_in, str):
                print('Using geom.csv from a file', self.geom_in)
                geom = read_geom_csv(self.geom_in)
            else:
                # no geom file was provided as input
                if num_channels > 6:
                    raise Exception(
                        'For more than six channels, we require that a geom.csv be provided')
                # otherwise make a trivial geometry file
                print('Making a trivial geom file.')
                geom = np.zeros((num_channels, 2))

            # Represent the preprocessed recording using a RecordingExtractor.
            # Use the declared sampling frequency (was hard-coded to 30000).
            recording = se.NumpyRecordingExtractor(
                X, samplerate=self.samplerate, geom=geom)

            # hard-code this for now -- idea: run many simultaneous jobs, each using only 2 cores
            # important to set certain environment variables in the .sh script that calls this .py script
            num_workers = 2

            # Call MountainSort4
            sorting = ml_ms4alg.mountainsort4(
                recording=recording,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                clip_size=self.clip_size,
                detect_threshold=self.detect_threshold,
                detect_interval=self.detect_interval,
                num_workers=num_workers,
            )

            # Write the firings.mda
            print('Writing firings.mda...')
            sf.SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)

            print('Computing cluster metrics...')
            cluster_metrics_path = tmpdir + '/cluster_metrics.json'
            _cluster_metrics(pre, self.firings_out, cluster_metrics_path)

            print('Computing isolation metrics...')
            isolation_metrics_path = tmpdir + '/isolation_metrics.json'
            pair_metrics_path = tmpdir + '/pair_metrics.json'
            _isolation_metrics(pre, self.firings_out,
                               isolation_metrics_path, pair_metrics_path)

            print('Combining metrics...')
            metrics_path = tmpdir + '/metrics.json'
            _combine_metrics(cluster_metrics_path,
                             isolation_metrics_path, metrics_path)

            # copy metrics.json to the output location
            shutil.copy(metrics_path, self.metrics_out)

            print('Creating label map...')
            label_map_path = tmpdir + '/label_map.mda'
            create_label_map(metrics=metrics_path,
                             label_map_out=label_map_path)

            print('Applying label map...')
            apply_label_map(firings=self.firings_out, label_map=label_map_path,
                            firings_out=self.firings_curated_out)
コード例 #27
0
class IronClust(mlpr.Processor):
    """SpikeForest wrapper that runs IronClust through spikesorters."""
    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[],
        description='List of channels to use.'
    )
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood'
    )
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood'
    )
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4,
        description='detection threshold'
    )
    prm_template_name = mlpr.StringParameter(
        optional=True, default='',
        description='.prm template file name'
    )
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering'
    )
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000,
        description='Use 0 for no bandpass filtering'
    )
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging'
    )
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0,
        description='Number of principal components per channel'
    )

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False, description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean', description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300, description='batch duration in seconds. clustering time duration')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20, description='compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30, description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True, description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8, description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0, description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32, description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1, description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='sort_mode')

    @staticmethod
    def install():
        """Install IronClust at a pinned commit and return the install path."""
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        """Install IronClust, sort ``recording_dir`` and write ``firings_out``.

        If ``channels`` is non-empty, only those channels are sorted. The
        sorter runtime is printed as ``#SF-SORTER-RUNTIME#<seconds>#``.

        Raises:
            Exception: if installing IronClust fails (the original error is
                chained as the cause).
        """
        import spikesorters as sorters
        print('IronClust......')

        try:
            ironclust_path = IronClust.install()
        except Exception as err:
            traceback.print_exc()
            raise Exception('Problem installing ironclust.') from err
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        if self.channels:
            # Honor the declared `channels` parameter (previously ignored),
            # the same way the KiloSort wrapper restricts the recording.
            import spikeextractors as se
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)

        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            # presumably cleaned up by the caller's _keep_temp_files handling -- confirm
            delete_output_folder=False
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
コード例 #28
0
class KiloSortOld(mlpr.Processor):
    """
    Kilosort wrapper for SpikeForest framework
      written by J. James Jun, May 21, 2019

    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0

    [Optional: Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/cortex-lab/KiloSort.git`
    2. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    3. Add `KILOSORT_PATH_DEV=...` in your .bashrc file.
    """

    NAME = 'KiloSort'
    VERSION = '0.2.4'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True,
        default=-1,
        description=
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(optional=True,
                                           default=-1,
                                           description='Currently unused')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=3,
                                           description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True,
                                       default=0.98,
                                       description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')

    @staticmethod
    def install():
        """Install KiloSort from GitHub at a pinned commit and return its path."""
        print('Auto-installing kilosort.')
        return install_kilosort(
            repo='https://github.com/cortex-lab/KiloSort.git',
            commit='3f33771f8fdf8c3846a7f8a75cc8c318b44ed48c')

    def run(self):
        """Run KiloSort on ``recording_dir`` and write the result to ``firings_out``.

        Work is done in a randomly named folder under ``$TEMPDIR`` (default
        ``/tmp``). That folder is removed afterwards, on success and on
        failure alike, unless the ``_keep_temp_files`` attribute is truthy.
        Any exception raised while sorting propagates to the caller.
        """
        # One cleanup policy for both the success and the failure path.
        keep_temp_files = getattr(self, '_keep_temp_files', False)
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            # makedirs also creates $TEMPDIR itself if it does not exist yet
            os.makedirs(tmpdir, exist_ok=True)
            sorting = kilosort_helper(recording=recording,
                                      tmpdir=tmpdir,
                                      detect_sign=self.detect_sign,
                                      adjacency_radius=self.adjacency_radius,
                                      detect_threshold=self.detect_threshold,
                                      merge_thresh=self.merge_thresh,
                                      freq_min=self.freq_min,
                                      freq_max=self.freq_max,
                                      pc_per_chan=self.pc_per_chan)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
コード例 #29
0
ファイル: klusta.py プロジェクト: samuelgarcia/spikeforest
class Klusta(mlpr.Processor):
    """SpikeForest wrapper that runs Klusta through spikesorters.

    Installation instructions::

        $ pip install Cython h5py tqdm
        $ pip install click klusta klustakwik2

    More information on klusta at:
      * https://github.com/kwikteam/phy
      * https://github.com/kwikteam/klusta
    """

    NAME = 'Klusta'
    VERSION = '0.2.2'  # wrapper VERSION
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS',
        'MKL_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
        'TEMPDIR',
    ]
    # CONTAINER = 'sha1://6d76f22e3b4eff52b430ef4649a8802f7da9e0ec/2019-05-13/klusta.simg'
    CONTAINER = 'sha1://182ff734d38e2ece30ed751de55807b0a8359959/2019-06-28/klusta.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    # Sorter parameters, forwarded verbatim to KlustaSorter.set_params().
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=None, description='')
    detect_sign = mlpr.FloatParameter(
        optional=True, default=-1, description='')
    threshold_strong_std_factor = mlpr.FloatParameter(
        optional=True, default=5, description='')
    threshold_weak_std_factor = mlpr.FloatParameter(
        optional=True, default=2, description='')
    n_features_per_channel = mlpr.IntegerParameter(
        optional=True, default=3, description='')
    num_starting_clusters = mlpr.IntegerParameter(
        optional=True, default=3, description='')
    extract_s_before = mlpr.IntegerParameter(
        optional=True, default=16, description='')
    extract_s_after = mlpr.IntegerParameter(
        optional=True, default=32, description='')

    def run(self):
        """Sort ``recording_dir`` with Klusta and save the result to ``firings_out``.

        The sorter works in a randomly named folder under ``$TEMPDIR``
        (default ``/tmp``) that it is asked to delete when done. The runtime
        returned by the sorter is printed as ``#SF-SORTER-RUNTIME#<seconds>#``.
        """
        import spikesorters as sorters

        print('Klusta......')
        recording = SFMdaRecordingExtractor(self.recording_dir)

        suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        base_dir = os.environ.get('TEMPDIR', '/tmp')
        output_folder = base_dir + '/klusta-' + suffix

        klusta = sorters.KlustaSorter(
            recording=recording,
            output_folder=output_folder,
            debug=True,
            delete_output_folder=True,
        )

        forwarded = (
            'adjacency_radius',
            'detect_sign',
            'threshold_strong_std_factor',
            'threshold_weak_std_factor',
            'n_features_per_channel',
            'num_starting_clusters',
            'extract_s_before',
            'extract_s_after',
        )
        klusta.set_params(**{name: getattr(self, name) for name in forwarded})

        elapsed = klusta.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))

        SFMdaSortingExtractor.write_sorting(
            sorting=klusta.get_result(), save_path=self.firings_out)
コード例 #30
0
class MountainSort4(mlpr.Processor):
    """SpikeForest wrapper that preprocesses and sorts with ml_ms4alg.mountainsort4."""
    NAME = 'MountainSort4'
    VERSION = '4.2.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS',
        'MKL_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
    ]
    CONTAINER = 'sha1://009406add7a55687cec176be912bc7685c2a4b1d/02-12-2019/mountainsort4.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15,
        description='Use None for no automated curation')

    def run(self):
        """Filter/whiten ``recording_dir``, sort it, and write ``firings_out``.

        Bandpass filtering runs unless both ``freq_min`` and ``freq_max`` are
        falsy; whitening runs when ``whiten`` is true. ``NUM_WORKERS`` (when
        set) becomes ``num_workers``. Automated curation is currently
        disabled, so ``noise_overlap_threshold`` is not used.
        """
        import spikeextractors as se
        import spiketoolkit as st
        import ml_ms4alg

        print('MountainSort4......')
        recording = se.MdaRecordingExtractor(self.recording_dir)
        workers_env = os.environ.get('NUM_WORKERS', None)
        num_workers = int(workers_env) if workers_env else workers_env

        # Preprocessing: optional bandpass filter, then optional whitening.
        wants_filter = self.freq_min or self.freq_max
        if wants_filter:
            recording = st.preprocessing.bandpass_filter(
                recording=recording,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
            )
        if self.whiten:
            recording = st.preprocessing.whiten(recording=recording)

        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
        )

        # NOTE: curation via ml_ms4alg.mountainsort4_curation (using
        # self.noise_overlap_threshold) is intentionally not applied here.

        se.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)