Пример #1
0
class IronClust(mlpr.Processor):
    """mlpr processor that runs the IronClust sorter on a recording directory.

    The recording is loaded with ``si.MdaRecordingExtractor``, optionally
    restricted to ``channels``, sorted via ``sf.sorters.ironclust`` and the
    result is written to ``firings_out``.
    """
    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording and write the firings file.

        Raises:
            Exception: if the ``IRONCLUST_SRC`` environment variable is unset
                or empty. Any error raised while loading, sorting or writing
                propagates unchanged after the temporary directory is removed.
        """
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')

        # A random 10-letter suffix keeps concurrent runs in separate
        # working directories under TEMPDIR (default /tmp).
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO (carried over from the original code)
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup point for both the success and the failure path;
            # an in-flight exception keeps propagating after the finally block.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
Пример #2
0
class MeanShift(mlpr.Processor):
    """mlpr processor that clusters a saved NumPy array with scikit-learn's
    MeanShift and stores the resulting labels array in ``labels_out``."""
    VERSION = '0.1.0'
    data = mlpr.Input()
    labels_out = mlpr.Output(is_array=True)
    bandwidth = mlpr.StringParameter()

    def run(self):
        """Fit MeanShift on the array in ``data`` and save its labels.

        ``bandwidth`` is a string: the literal ``'auto'`` passes
        ``bandwidth=None`` to the estimator; anything else is parsed with
        ``float``.
        """
        # Aliased so the estimator is not confused with this processor class.
        from sklearn.cluster import MeanShift as SklearnMeanShift
        import numpy as np

        bw = None if self.bandwidth == 'auto' else float(self.bandwidth)
        estimator = SklearnMeanShift(bandwidth=bw)
        fitted = estimator.fit(np.load(self.data))

        # Write to a '.npy'-suffixed staging path first, then move the file
        # onto the exact output path requested by the framework.
        staging_path = self.labels_out + '.npy'
        np.save(staging_path, fitted.labels_)
        os.rename(staging_path, self.labels_out)
Пример #3
0
class H5ToDict(mlpr.Processor):
    """mlpr processor that converts an input file via ``h5_to_dict`` and
    writes the resulting object as JSON."""
    NAME = 'H5ToDict'
    VERSION = '0.1.1'

    # Input file handed to h5_to_dict.
    h5_in = mlpr.Input()

    # Optional upload target forwarded to h5_to_dict; empty means "none".
    upload_to = mlpr.StringParameter(optional=True, default='')

    # Path of the JSON file to write.
    json_out = mlpr.Output()

    def run(self):
        """Convert ``h5_in`` and dump the result to ``json_out``."""
        # An empty/falsy parameter value is forwarded as None.
        destination = self.upload_to or None
        contents = h5_to_dict(self.h5_in, upload_to=destination)
        with open(self.json_out, 'w') as out_file:
            # ignore_nan=True is passed through to simplejson unchanged.
            simplejson.dump(contents, out_file, ignore_nan=True)
Пример #4
0
class JRClust(mlpr.Processor):
    """
    JRClust v4 wrapper
      written by J. James Jun, May 8, 2019
      modified from `spikeforest/spikesorters/ironclust/ironclust.py`

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/JaneliaSciComp/JRCLUST`
    2. Activate conda environment for SpikeForest
    3. Create `JRCLUST_PATH` and `MDAIO_PATH`
    4. Flatiron execution: `module load matlab/R2019a cuda/10.0.130_410.48`
    DO NOT LOAD `module load gcc`.
    DO NOT USE `matlab/R2018b` (it will crash while creating a parpool).

    See: James Jun, et al. Real-time spike sorting platform for
    high-density extracellular probes with ground-truth validation and drift correction
    https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True,
        default=-1,
        description=
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(optional=True,
                                           default=50,
                                           description='')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=4.5,
                                           description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=3000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True,
        default=0.98,
        description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True,
        default=1,
        description='Number of principal components per channel')

    # Additional sorter options.
    # NOTE(review): the original comment here read "added in version 0.2.4",
    # which does not match this class's VERSION ('0.1.3'); it looks carried
    # over from the IronClust wrapper this file was modified from -- confirm.
    filter_type = mlpr.StringParameter(
        optional=True,
        default='bandpass',
        description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(optional=True,
                                           default='none',
                                           description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(optional=True,
                                      default=30,
                                      description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(optional=True,
                                 default=1,
                                 description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(optional=True,
                                    default=0,
                                    description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True,
        default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        """Run JRClust on ``recording_dir`` and write ``firings_out``.

        Every declared parameter in ``self.PARAMETERS`` is forwarded by name
        to ``jrclust_helper``. The temporary working directory is removed
        afterwards (on success and on failure) unless the instance has a
        truthy ``_keep_temp_files`` attribute. Exceptions from loading,
        sorting or writing propagate unchanged.
        """
        tmpdir = _get_tmpdir('jrclust')

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Snapshot of every processor parameter, keyed by its name.
            all_params = {
                param0.name: getattr(self, param0.name)
                for param0 in self.PARAMETERS
            }

            sorting = jrclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **all_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # One cleanup point for success and failure. The existence check
            # also covers the case where the directory was never created or
            # was already removed.
            keep_temp_files = getattr(self, '_keep_temp_files', False)
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
Пример #5
0
class IronClust(mlpr.Processor):
    """mlpr processor that runs IronClust through ``spikesorters.IronClustSorter``.

    ``run`` auto-installs IronClust at a pinned commit, configures the sorter
    from the processor parameters, runs it on ``recording_dir`` and writes the
    result to ``firings_out``.

    NOTE(review): ``channels`` is declared as a parameter but ``run`` never
    reads it, so no channel sub-selection happens here -- confirm whether that
    is intended.
    """
    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[],
        description='List of channels to use.'
    )
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood'
    )
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood'
    )
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4,
        description='detection threshold'
    )
    prm_template_name = mlpr.StringParameter(
        optional=True, default='',
        description='.prm template file name'
    )
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering'
    )
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000,
        description='Use 0 for no bandpass filtering'
    )
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging'
    )
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0,
        description='Number of principal components per channel'
    )

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False, description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean', description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300, description='batch duration in seconds. clustering time duration')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20, description='compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30, description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True, description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8, description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0, description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32, description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1, description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='sort_mode')

    @staticmethod
    def install():
        """Install IronClust at the pinned commit and return what
        ``install_ironclust`` returns (used by ``run`` as the IronClust path)."""
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        """Install IronClust, run the sorter and write the firings file.

        Raises:
            Exception: ``'Problem installing ironclust.'`` if the install step
                fails; the underlying error is attached as ``__cause__``.
                Errors from the sorter or from writing the result propagate
                unchanged.
        """
        import spikesorters as sorters
        print('IronClust......')

        try:
            ironclust_path = IronClust.install()
        except Exception as err:
            # Print the original traceback, then chain the cause instead of
            # discarding it. KeyboardInterrupt/SystemExit are not converted.
            traceback.print_exc()
            raise Exception('Problem installing ironclust.') from err
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random 10-letter suffix so each run gets its own output folder
        # under TEMPDIR (default /tmp).
        code = ''.join(random.choice(string.ascii_uppercase)
                       for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            # Per the original author's note, removal of the output folder is
            # handled by _keep_temp_files one level above this processor.
            delete_output_folder=False
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        # The value returned by sorter.run() is reported as the sorter
        # runtime (presumably seconds -- confirm against spikesorters).
        runtime = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(runtime))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Пример #6
0
class IronClust(mlpr.Processor):
    """mlpr processor that runs IronClust via ``ironclust_helper``.

    The recording is loaded with ``se.MdaRecordingExtractor``, optionally
    restricted to ``channels``, sorted, and the result is written to
    ``firings_out``. The IronClust location is read from the
    ``IRONCLUST_PATH`` environment variable.
    """
    NAME = 'IronClust'
    VERSION = '0.2.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.m']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=5,
                                           description='detection threshold')
    prm_template_name = mlpr.StringParameter(
        optional=True, description='.prm template file name')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True,
                                       default=0.98,
                                       description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')

    def run(self):
        """Sort the recording and write the firings file.

        The temporary working directory is removed afterwards (on success and
        on failure) unless the instance has a truthy ``_keep_temp_files``
        attribute.

        Raises:
            Exception: if the ``IRONCLUST_PATH`` environment variable is unset
                or empty. Errors from loading, sorting or writing propagate
                unchanged.
        """
        ironclust_path = os.environ.get('IRONCLUST_PATH', None)
        if not ironclust_path:
            raise Exception('Environment variable not set: IRONCLUST_PATH')

        # Random 10-letter suffix so each run gets its own working directory
        # under TEMPDIR (default /tmp).
        code = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = ironclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_path=ironclust_path,
                params=params,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # One cleanup point for success and failure; an in-flight
            # exception keeps propagating after this block.
            keep_temp_files = getattr(self, '_keep_temp_files', False)
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)