class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting against ground truth and export the result table.

    Takes two firings files (the tested sorting and the true sorting) plus a
    list of true unit ids.  The comparison dataframe is written to
    ``json_out`` as a dict and to ``html_out`` as an HTML string; both go
    through ``_write_json_file``.
    """
    VERSION = '0.2.6'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')
    # Previously used container images, kept for reference:
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        """Run the comparison and write both output files."""
        banner = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(banner.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        ground_truth = SFMdaSortingExtractor(firings_file=self.firings_true)

        # Restrict the ground truth only when a non-empty unit list was given.
        restrict = (self.units_true is not None) and (len(self.units_true) > 0)
        if restrict:
            ground_truth = si.SubSortingExtractor(
                parent_sorting=ground_truth, unit_ids=self.units_true)

        comparison = SortingComparison(ground_truth, tested, delta_tp=30)
        table = get_comparison_data_frame(comparison=comparison)

        # Build both representations before writing anything.
        table_as_dict = table.transpose().to_dict()
        table_as_html = table.to_html(index=False)
        _write_json_file(table_as_dict, self.json_out)
        # The HTML string is also stored through the JSON writer.
        _write_json_file(table_as_html, self.html_out)
class ExtractTwoPhotonSeriesMp4(mlpr.Processor):
    """Write the TwoPhotonSeries data referenced by an NWB file as an mp4."""
    # NOTE(review): NAME does not match the class name; this looks like a
    # copy/paste leftover. Left unchanged because the name is part of the
    # processor's interface -- confirm before renaming.
    NAME = 'H5ToDict'
    VERSION = '0.1.1'

    # Inputs
    nwb_in = mlpr.Input()

    # Outputs
    mp4_out = mlpr.Output()

    def run(self):
        """Locate the series data as a local .npy file and encode it.

        The NWB file is first converted with the cache enabled; if the
        referenced .npy file cannot be realized, the conversion is repeated
        with the cache disabled.  If that also fails, an error is set on the
        processor and nothing is written.
        """
        local_npy = None
        for use_cache in (True, False):
            nwb_obj = nwb_to_dict(self.nwb_in, use_cache=use_cache)
            npy_path = nwb_obj['acquisition']['TwoPhotonSeries']['_datasets']['data']['_data']
            local_npy = mt.realizeFile(npy_path)
            if local_npy:
                break

        if not local_npy:
            self._set_error('Unable to realize npy file: {}'.format(npy_path))
            return

        frames = np.load(local_npy)

        # imageio.mimwrite cannot write to a memory buffer because of a bug
        # (https://github.com/imageio/imageio/issues/157), so the movie is
        # written straight to the output path.
        imageio.mimwrite(self.mp4_out, frames, format='mp4', fps=10)
Пример #3
0
class KiloSort(mlpr.Processor):
    """Processor wrapper that runs KiloSort via ``kilosort_helper``.

    The recording directory is loaded as an MDA recording (optionally
    restricted to ``channels``), sorted inside a scratch directory, and the
    resulting sorting is written to ``firings_out``.
    """
    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
    ]
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording in a scratch directory and write the firings.

        The scratch directory lives under ``$TEMPDIR`` (default ``/tmp``)
        with a random 10-letter suffix and is removed on success and on
        failure; any exception is re-raised after cleanup.
        """
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + suffix

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
            )
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        except BaseException:
            # Failure path: the directory may not have been created yet.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        else:
            # Success path: the scratch directory is always removed.
            shutil.rmtree(tmpdir)
class CreateTimeseriesPlot(mlpr.Processor):
    """Plot a short window of the bandpass-filtered recording as a .jpg.

    The window ends at the middle frame of the recording and spans up to
    4000 frames; at most the first 20 channels are drawn.
    """
    NAME = 'CreateTimeseriesPlot'
    VERSION = '0.1.7'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        """Load, filter and plot the recording, then save it to ``jpg_out``."""
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                      download=False)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0,
                                          channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0,
                                           freq_min=300,
                                           freq_max=6000)
        N = R.getNumFrames()
        N2 = int(N / 2)
        # Clamp the window start: for recordings shorter than 8000 frames
        # N2 - 4000 would be negative, which is not a valid frame index.
        window_start = max(0, N2 - 4000)
        channels = R.getChannelIds()
        if len(channels) > 20:
            # Keep the plot readable by drawing only the first 20 channels.
            channels = channels[0:20]
        sw.TimeseriesWidget(recording=R,
                            trange=[window_start, N2],
                            channels=channels,
                            width=12,
                            height=5).plot()
        save_plot(self.jpg_out)
Пример #5
0
class ComputeUnitDetail(mlpr.Processor):
    """Compute the median waveform of one unit and store it as JSON.

    The output JSON object has two keys: ``channel_ids`` and
    ``average_waveform``.
    """
    NAME = 'ComputeUnitDetail'
    VERSION = '0.1.0'
    CONTAINER = None

    recording_dir = mlpr.Input(description='Recording directory',
                               optional=False,
                               directory=True)
    firings = mlpr.Input(description='Input firings.mda file')
    unit_id = mlpr.IntegerParameter(description='Unit ID')
    json_out = mlpr.Output(description='Output .json file')

    def run(self):
        """Sample spike waveforms for ``unit_id`` and write their median."""
        # Fixed: the declared input is `recording_dir`; the original read the
        # non-existent attribute `recording_directory`.
        recording = SFMdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        waveforms0 = _get_random_spike_waveforms(recording=recording,
                                                 sorting=sorting,
                                                 unit=self.unit_id)
        channel_ids = recording.get_channel_ids()
        # Median over the third axis -- presumably the spike axis of a
        # (channels, samples, spikes) array; TODO confirm against
        # _get_random_spike_waveforms.
        avg_waveform = np.median(waveforms0, axis=2)
        ret = dict(channel_ids=channel_ids,
                   average_waveform=avg_waveform.tolist())
        with open(self.json_out, 'w') as f:
            json.dump(ret, f)
class CreateWaveformsPlot(mlpr.Processor):
    """Plot unit waveforms of a sorting on the filtered recording as a .jpg.

    At most 20 channels and at most 20 units are drawn.
    """
    NAME = 'CreateWaveformsPlot'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    units = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        """Load recording and sorting, pick channels/units, plot and save."""
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                      download=True)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0,
                                          channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0,
                                           freq_min=300,
                                           freq_max=6000)
        S = si.MdaSortingExtractor(firings_file=self.firings)

        channels = R.getChannelIds()
        if len(channels) > 20:
            channels = channels[0:20]

        # An explicit unit list wins; otherwise use every unit in the sorting.
        if len(self.units) > 0:
            units = self.units
        else:
            units = S.getUnitIds()
        if len(units) > 20:
            # Subsample evenly, then cap at 20. The stride alone does not
            # enforce the limit: for 21-39 units it is 1, and for counts that
            # are not a multiple of 20 it leaves more than 20 entries.
            stride = int(len(units) / 20)
            units = units[::stride][:20]

        sw.UnitWaveformsWidget(recording=R,
                               sorting=S,
                               channels=channels,
                               unit_ids=units).plot()
        save_plot(self.jpg_out)
class GenSortingComparisonTableNew(mlpr.Processor):
    """Ground-truth comparison table based on ``st.comparison``.

    Runs ``compare_sorter_to_ground_truth``, merges its counts and
    performance tables, renames the columns and derives ``matched_unit``,
    ``f_p`` and ``f_n``.  The table is written as a dict to ``json_out`` and
    as an HTML string to ``html_out`` (both through ``_write_json_file``).
    """
    VERSION = '0.3.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')
    # Previously used container images, kept for reference:
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        """Run the comparison and write both output files."""
        banner = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(banner.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        ground_truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        restrict = (self.units_true is not None) and (len(self.units_true) > 0)
        if restrict:
            ground_truth = si.SubSortingExtractor(
                parent_sorting=ground_truth, unit_ids=self.units_true)

        comparison = st.comparison.compare_sorter_to_ground_truth(
            gt_sorting=ground_truth,
            tested_sorting=tested,
            delta_time=0.3,
            min_accuracy=0,
            compute_misclassification=False,
            exhaustive_gt=False,  # Fix this in future
        )

        # Counts and performance side by side, one row per ground-truth unit.
        merged = pd.concat(
            [comparison.count, comparison.get_performance()], axis=1)
        table = merged.reset_index()

        column_names = {
            'gt_unit_id': 'unit_id',
            'fp': 'num_false_positives',
            'fn': 'num_false_negatives',
            'tested_id': 'best_unit',
            'tp': 'num_matches',
        }
        table = table.rename(columns=column_names)
        table['matched_unit'] = table['best_unit']
        table['f_p'] = 1 - table['precision']
        table['f_n'] = 1 - table['recall']

        # Build both representations before writing anything.
        table_as_dict = table.transpose().to_dict()
        table_as_html = table.to_html(index=False)
        _write_json_file(table_as_dict, self.json_out)
        _write_json_file(table_as_html, self.html_out)
Пример #8
0
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit summary info and write it as a JSON list.

    Each entry holds ``unit_id``, ``snr``, ``peak_channel``, ``num_events``
    and ``firing_rate``.  Templates are computed on the first 1e6 frames of
    the bandpass-filtered recording; event counts and firing rates use the
    full sorting and recording.
    """
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.1'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        """Load the data, compute templates and assemble the info list."""
        raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=True)
        if (self.channel_ids) and (len(self.channel_ids) > 0):
            raw = si.SubRecordingExtractor(parent_recording=raw,
                                           channel_ids=self.channel_ids)
        recording = sw.lazyfilters.bandpass_filter(
            recording=raw, freq_min=300, freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)

        # Head of the data (frames [0, 1e6)) used for the templates; the
        # recording part is wrapped in a MemoryRecordingExtractor.
        head_end = int(1e6)
        head_recording = si.SubRecordingExtractor(
            parent_recording=recording, start_frame=0, end_frame=head_end)
        head_recording = MemoryRecordingExtractor(
            parent_recording=head_recording)
        head_sorting = si.SubSortingExtractor(
            parent_sorting=sorting, start_frame=0, end_frame=head_end)

        # Fall back to every unit of the sorting when none were requested.
        unit_ids = self.unit_ids
        if (not unit_ids) or (len(unit_ids) == 0):
            unit_ids = sorting.getUnitIds()

        channel_noise_levels = compute_channel_noise_levels(
            recording=recording)
        print('computing templates...')
        templates = compute_unit_templates(recording=head_recording,
                                           sorting=head_sorting,
                                           unit_ids=unit_ids)
        print('.')

        infos = []
        for index, unit_id in enumerate(unit_ids):
            template = templates[index]
            entry = {
                'unit_id': int(unit_id),
                'snr': compute_template_snr(template, channel_noise_levels),
            }
            # Peak channel: row of the template with the largest absolute
            # value.
            peak_index = np.argmax(np.max(np.abs(template), axis=1))
            entry['peak_channel'] = int(recording.getChannelIds()[peak_index])
            spike_train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            entry['num_events'] = int(len(spike_train))
            duration = (recording.getNumFrames() /
                        recording.getSamplingFrequency())
            entry['firing_rate'] = float(len(spike_train) / duration)
            infos.append(entry)
        write_json_file(self.json_out, infos)
Пример #9
0
class MountainSort4(mlpr.Processor):
    """Processor wrapper around ``sf.sorters.mountainsort4``.

    Loads the dataset directory as an MDA recording, runs the sorter with
    the declared parameters and writes the sorting to ``firings_out``.
    """
    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    dataset_dir = mlpr.Input('Directory of dataset', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15,
        description='Use None for no automated curation')

    def run(self):
        """Sort the recording and write the resulting firings file."""
        recording = si.MdaRecordingExtractor(self.dataset_dir)

        # NUM_WORKERS comes from the environment; a missing or non-positive
        # value is passed to the sorter as None.
        requested_workers = int(os.environ.get('NUM_WORKERS', -1))
        num_workers = requested_workers if requested_workers > 0 else None

        sorting = sf.sorters.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            whiten=self.whiten,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            noise_overlap_threshold=self.noise_overlap_threshold,
            num_workers=num_workers,
        )
        si.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)
class CreateEfficientAccessRecordingFile(mlpr.Processor):
    """Store a recording in an HDF5 file as fixed-size per-channel segments.

    The file holds metadata datasets (``segment_size``, ``num_segments``,
    ``num_channels``, ``channel_ids``, ``num_timepoints``, ``samplerate``,
    ``geom``, ``recording_hash``) and one dataset ``part-<channel>-<j>`` per
    channel and segment.  Every part has ``segment_size`` samples; the tail
    of the last segment is zero-padded.
    """
    NAME = 'CreateEfficientAccessRecordingFile'
    VERSION = '0.1.2'
    recording = mlpr.Input()
    segment_size = mlpr.IntegerParameter(optional=True, default=1000000)
    hdf5_out = mlpr.Output()

    def run(self):
        """Read the recording segment by segment and write ``hdf5_out``."""
        import h5py
        recording = self.recording
        segment_size = self.segment_size
        channel_ids = recording.get_channel_ids()
        samplerate = recording.get_sampling_frequency()
        M = len(channel_ids)  # number of channels
        N = recording.get_num_frames()  # number of timepoints
        num_segments = int(np.ceil(N / segment_size))

        try:
            channel_locations = recording.get_channel_locations(
                channel_ids=channel_ids)
            nd = len(channel_locations[0])
            geom = np.zeros((M, nd))
            for m in range(M):
                geom[m, :] = channel_locations[m]
        except Exception:
            # Best effort: fall back to an all-zero 2-D geometry when the
            # locations are unavailable. `Exception` (not a bare except)
            # lets KeyboardInterrupt / SystemExit propagate.
            nd = 2
            geom = np.zeros((M, nd))

        with h5py.File(self.hdf5_out, "w") as f:
            f.create_dataset('segment_size', data=[segment_size])
            f.create_dataset('num_segments', data=[num_segments])
            f.create_dataset('num_channels', data=[M])
            f.create_dataset('channel_ids', data=np.array(channel_ids))
            f.create_dataset('num_timepoints', data=[N])
            f.create_dataset('samplerate', data=[samplerate])
            f.create_dataset('geom', data=geom)
            # `hash` may be exposed either as a method or as an attribute.
            if callable(recording.hash):
                hash0 = recording.hash()
            else:
                hash0 = recording.hash
            f.create_dataset('recording_hash', data=np.array([hash0.encode()]))
            for j in range(num_segments):
                # Frame range [t1, t2) of this segment. t1 is never negative
                # and t2 is already clipped to N, so the data always starts
                # at column 0 of the zero-initialised buffer.
                t1 = int(j * segment_size)
                t2 = int(min(N, t1 + segment_size))
                segment = np.zeros((M, segment_size),
                                   dtype=float)  # fix dtype here
                segment[:, :t2 - t1] = recording.get_traces(
                    start_frame=t1, end_frame=t2)

                for ii, ch in enumerate(channel_ids):
                    f.create_dataset('part-{}-{}'.format(ch, j),
                                     data=segment[ii, :].ravel())
Пример #11
0
class Waveclus(mlpr.Processor):
    """
    Wave_clus wrapper
      written by J. James Jun, May 21, 2019

    [Optional: Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/csn-le/wave_clus.git`
    2. Activate conda environment for SpikeForest
    3. Create `WAVECLUS_PATH_DEV`

    Algorithm website:
    https://github.com/csn-le/wave_clus/wiki
    """

    NAME = 'waveclus'
    VERSION = '0.0.5'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR',
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    def run(self):
        """Run ``waveclus_helper`` in a scratch directory and save the sorting.

        The scratch directory is removed afterwards (on success and on
        failure) unless the instance has a truthy ``_keep_temp_files``
        attribute.  Exceptions are re-raised after cleanup.
        """
        tmpdir = _get_tmpdir('waveclus')
        keep_temp_files = lambda: getattr(self, '_keep_temp_files', False)

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            # if len(self.channels) > 0:
            #     recording = se.SubRecordingExtractor(
            #         parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Forward every declared processor parameter to the helper.
            declared_params = {
                spec.name: getattr(self, spec.name)
                for spec in self.PARAMETERS
            }
            sorting = waveclus_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **declared_params
            )
            SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir) and not keep_temp_files():
                print('erased temp file 1')
                shutil.rmtree(tmpdir)
            raise
        else:
            if not keep_temp_files():
                print('erased temp file 2')
                shutil.rmtree(tmpdir)
Пример #12
0
class ComputeMandelbrotWithError(mlpr.Processor):
    """Compute a Mandelbrot image, optionally failing on purpose.

    With ``throw_error`` set, ``run`` waits two seconds and raises;
    otherwise it calls ``compute_mandelbrot`` and saves the result to
    ``output_npy``.
    """
    NAME = 'ComputeMandelbrotWithError'
    VERSION = '0.1.4'

    # The bounds are real-valued (the defaults 0.5 and +/-1.25 are not
    # integers), so they are declared as float parameters rather than
    # integer ones.
    # NOTE(review): if cached results depend on how these were typed before,
    # consider bumping VERSION -- confirm against the caching behaviour.
    xmin = mlpr.FloatParameter('The minimum x value',
                               optional=True,
                               default=-2)
    xmax = mlpr.FloatParameter('The maximum x value',
                               optional=True,
                               default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value',
                               optional=True,
                               default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value',
                               optional=True,
                               default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension',
        optional=True,
        default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations',
                                     optional=True,
                                     default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)',
        optional=True,
        default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset',
                                               optional=True,
                                               default=0)
    throw_error = mlpr.BoolParameter(
        'Whether to intentionally throw an error for testing purposes',
        optional=True,
        default=False)

    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Compute and save the image, or raise when ``throw_error`` is set.

        Raises:
            Exception: deliberately, when ``throw_error`` is true.
        """
        import time
        if self.throw_error:
            print('Intentionally throwing error in 2 seconds...')
            time.sleep(2)
            raise Exception('Intentionally throwing error.')
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(xmin=self.xmin,
                               xmax=self.xmax,
                               ymin=self.ymin,
                               ymax=self.ymax,
                               num_x=self.num_x,
                               num_iter=self.num_iter,
                               subsampling_factor=self.subsampling_factor,
                               subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
Пример #13
0
class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting against ground truth and export the result table.

    The dataframe produced by ``sw.SortingComparisonTable`` is written to
    ``json_out`` as a dict and to ``html_out`` as an HTML string, both via
    ``_write_json_file``.
    """
    VERSION = '0.1.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')

    def run(self):
        """Run the comparison and write both output files."""
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        sorting_true = si.MdaSortingExtractor(firings_file=self.firings_true)
        # Guard against a missing unit list as well as an empty one, so that
        # `len(None)` cannot raise; this matches the other comparison
        # processors in this file.
        if (self.units_true is not None) and (len(self.units_true) > 0):
            sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true,
                                                  unit_ids=self.units_true)
        SC = st.comparison.SortingComparison(sorting_true, sorting)
        df = sw.SortingComparisonTable(comparison=SC).getDataframe()
        table_dict = df.transpose().to_dict()
        table_html = df.to_html(index=False)
        _write_json_file(table_dict, self.json_out)
        # The HTML string is also stored through the JSON writer.
        _write_json_file(table_html, self.html_out)
class GenerateMearecRecording(mlpr.Processor):
    """Generate a MEArec recording from a templates file.

    Starts from ``mr.get_default_recordings_params()``, overrides the
    entries listed below with the processor parameters (the same ``seed`` is
    used for all three sections) and saves the generated recording to
    ``recording_out``.
    """
    NAME = "GenerateMearecRecording"
    VERSION = "0.1.0"

    # input file
    templates_in = mlpr.Input(description='.h5 file containing templates')

    # output file
    recording_out = mlpr.Output()

    # recordings params
    drifting = mlpr.BoolParameter()
    noise_level = mlpr.FloatParameter()
    bursting = mlpr.BoolParameter()
    shape_mod = mlpr.BoolParameter()

    # spiketrains params
    duration = mlpr.FloatParameter()
    n_exc = mlpr.IntegerParameter()
    n_inh = mlpr.IntegerParameter()

    # templates params
    min_dist = mlpr.FloatParameter()

    # seed
    seed = mlpr.IntegerParameter()

    def run(self):
        """Assemble the parameters, generate the recording and save it."""
        params = deepcopy(mr.get_default_recordings_params())

        # Per-section overrides applied on top of the defaults.
        # (Setting recordings.chunk_conv_duration = 0 would turn off
        # parallel execution; it is intentionally not set.)
        overrides = {
            'recordings': {
                'drifting': self.drifting,
                'noise_level': self.noise_level,
                'bursting': self.bursting,
                'shape_mod': self.shape_mod,
                'seed': self.seed,
            },
            'spiketrains': {
                'duration': self.duration,
                'n_exc': self.n_exc,
                'n_inh': self.n_inh,
                'seed': self.seed,
            },
            'templates': {
                'min_dist': self.min_dist,
                'seed': self.seed,
            },
        }
        for section, values in overrides.items():
            for key, value in values.items():
                params[section][key] = value

        # mr.load_templates requires the file extension, so the input is
        # copied to a sibling path that ends in '.h5'.
        # NOTE(review): this copy is never removed afterwards -- confirm
        # whether it should be cleaned up.
        h5_copy = self.templates_in + '.h5'
        shutil.copyfile(self.templates_in, h5_copy)
        tempgen = mr.load_templates(Path(h5_copy))

        recgen = mr.gen_recordings(params=params,
                                   tempgen=tempgen,
                                   verbose=False)
        mr.save_recording_generator(recgen, self.recording_out)
        del recgen
Пример #15
0
class YASS(mlpr.Processor):
    """Processor wrapper that runs YASS via ``yass_helper``.

    The recording directory is loaded as an MDA recording (optionally
    restricted to ``channels``), sorted inside a scratch directory, and the
    resulting sorting is written to ``firings_out``.
    """
    NAME = 'YASS'
    VERSION = '0.1.0'
    # used by container to pass the env variables
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    # paramfile_out = mlpr.Output('YASS yaml config file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        """Sort the recording in a scratch directory and write the firings.

        The scratch directory lives under ``$TEMPDIR`` (default ``/tmp``)
        with a random 10-letter suffix.  It is removed afterwards, on both
        success and failure, unless the instance has a truthy
        ``_keep_temp_files`` attribute.  Exceptions propagate to the caller.
        """
        code = ''.join(random.choice(string.ascii_uppercase)
                       for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            # The helper also returns the generated yaml config path, which
            # is currently unused (see the commented-out paramfile_out).
            sorting, _yaml_file = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        finally:
            # One cleanup path for success and failure. Previously the
            # failure path deleted the directory even when _keep_temp_files
            # was set, which defeats keeping the files for debugging.
            keep = getattr(self, '_keep_temp_files', False)
            if not keep and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
Пример #16
0
class ComputeUnitsInfo(mlpr.Processor):
    """Write the result of ``sa.compute_units_info`` as a JSON file."""
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.3'
    recording = mlpr.Input()
    sorting = mlpr.Input()
    json_out = mlpr.Output()

    def run(self):
        """Compute the units info and dump it to ``json_out``."""
        units_info = sa.compute_units_info(
            recording=self.recording,
            sorting=self.sorting,
        )
        with open(self.json_out, 'w') as out_file:
            json.dump(units_info, out_file)
Пример #17
0
class AffinityPropagation(mlpr.Processor):
    """Cluster a saved array with scikit-learn's AffinityPropagation."""
    VERSION = '0.1.0'
    data = mlpr.Input()
    labels_out = mlpr.Output(is_array=True)
    damping = mlpr.FloatParameter()

    def run(self):
        """Fit the model on ``data`` and save the labels to ``labels_out``."""
        # Imported under an alias because this processor class has the same
        # name as the scikit-learn estimator.
        from sklearn.cluster import AffinityPropagation as SklearnEstimator
        import numpy as np

        model = SklearnEstimator(damping=self.damping)
        model = model.fit(np.load(self.data))

        # The labels are saved under a '.npy' name first and then moved to
        # the exact output path.
        staging_path = self.labels_out + '.npy'
        np.save(staging_path, model.labels_)
        os.rename(staging_path, self.labels_out)
Пример #18
0
class ComputeDatasetInfo(mlpr.Processor):
    """Write basic facts about a recording to a JSON file.

    The output has the keys ``samplerate``, ``num_channels`` and
    ``duration_sec`` (number of frames divided by the sampling frequency).
    """
    NAME = 'ComputeDatasetInfo'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        """Load the recording and write its summary to ``json_out``."""
        recording = si.MdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=False)
        samplerate = recording.getSamplingFrequency()
        num_channels = len(recording.getChannelIds())
        duration_sec = recording.getNumFrames() / samplerate
        info = {
            'samplerate': samplerate,
            'num_channels': num_channels,
            'duration_sec': duration_sec,
        }
        write_json_file(self.json_out, info)
Пример #19
0
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit summary info for a sorting and write it as JSON.

    For every unit the output list holds a dict with the keys ``unit_id``,
    ``snr``, ``peak_channel``, ``num_events`` and ``firing_rate``.
    """
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.5k'
    CONTAINER = _CONTAINER
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        """Compute the info for each unit and write the list to ``json_out``.

        The peak channel of a unit is the channel with the largest
        peak-to-peak template amplitude. The SNR is the absolute peak of the
        bandpass-filtered (300-6000 Hz) template on that channel divided by
        that channel's noise level.
        """
        import spikewidgets as sw

        recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if self.channel_ids:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channel_ids)

        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        unit_ids = self.unit_ids
        if not unit_ids:
            # No explicit selection: report on every unit in the sorting.
            unit_ids = sorting.getUnitIds()

        # NOTE(review): assumed to be ordered like recording.getChannelIds(),
        # since it is indexed by channel position below -- confirm.
        channel_noise_levels = compute_channel_noise_levels(recording=recording)

        # No longer use subset to compute the templates
        templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)

        # Loop-invariant values, computed once.
        recording_channel_ids = recording.getChannelIds()
        duration_sec = recording.getNumFrames() / recording.getSamplingFrequency()

        ret = []
        for i, unit_id in enumerate(unit_ids):
            template = templates[i]
            # Axis 0 of the template is treated as the channel axis.
            p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
            peak_channel_index = np.argmax(p2p_amps_on_channels)
            # Translate the position into the channel ID. The two differ
            # whenever channel IDs are not simply 0..N-1 (e.g. after the
            # channel sub-selection above).
            peak_channel = recording_channel_ids[peak_channel_index]
            # SubRecordingExtractor selects by channel ID, not by position.
            R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
            R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
            templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
            template2 = templates2[0]
            train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            info0 = dict()
            info0['unit_id'] = int(unit_id)
            info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
            # peak_channel already is the channel ID; do not index again.
            info0['peak_channel'] = int(peak_channel)
            info0['num_events'] = int(len(train))
            info0['firing_rate'] = float(len(train) / duration_sec)
            ret.append(info0)
        write_json_file(self.json_out, ret)
Пример #20
0
class ComputeMandelbrot(mlpr.Processor):
    """Compute a (possibly subsampled) Mandelbrot array and save it as .npy.

    All parameters are forwarded unchanged to ``compute_mandelbrot``.
    """
    NAME = 'ComputeMandelbrot'
    VERSION = '0.1.4'

    # The bounds of the region are real numbers (the defaults include 0.5 and
    # +/-1.25), so they are declared as float rather than integer parameters.
    xmin = mlpr.FloatParameter('The minimum x value',
                               optional=True,
                               default=-2)
    xmax = mlpr.FloatParameter('The maximum x value',
                               optional=True,
                               default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value',
                               optional=True,
                               default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value',
                               optional=True,
                               default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension',
        optional=True,
        default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations',
                                     optional=True,
                                     default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)',
        optional=True,
        default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset',
                                               optional=True,
                                               default=0)

    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Run ``compute_mandelbrot`` and save the result to ``output_npy``."""
        print('=== ComputeMandelbrot ===', self.subsampling_factor,
              self.subsampling_offset)
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(xmin=self.xmin,
                               xmax=self.xmax,
                               ymin=self.ymin,
                               ymax=self.ymax,
                               num_x=self.num_x,
                               num_iter=self.num_iter,
                               subsampling_factor=self.subsampling_factor,
                               subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
Пример #21
0
class ComputeNthPrime(mlpr.Processor):
    """Write the result of ``nth_prime_number(n)`` to a text file."""
    NAME = 'ComputeNthPrime'
    VERSION = '0.1.1'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output text file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Compute the value for ``n`` and store its text form in ``output``."""
        text = '{}'.format(nth_prime_number(self.n))
        with open(self.output, 'w') as out_file:
            out_file.write(text)
Пример #22
0
class ComputeNPrimes(mlpr.Processor):
    """Save the result of ``compute_n_primes(n)`` as a .npy file."""
    NAME = 'ComputeNPrimes'
    VERSION = '0.1.3'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        """Compute the primes, report the last one, and save them all."""
        count = self.n
        prime_values = compute_n_primes(count)
        last_prime = prime_values[-1]
        print('Prime {}: {}'.format(count, last_prime))
        np.save(self.output, prime_values)
Пример #23
0
class PlotUnitWaveforms(mlpr.Processor):
    """Plot unit waveforms for a sorting and save the figure to a file."""
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(description='Recording directory',
                               directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        """Build the waveform widget, plot it, and save via ``save_plot``."""
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
        selected_channels = self.channels
        if len(selected_channels) > 0:
            # Restrict the recording to the requested channels.
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=selected_channels)
        sort = si.MdaSortingExtractor(firings_file=self.firings)
        widget = sw.UnitWaveformsWidget(recording=rec, sorting=sort)
        widget.plot()
        save_plot(self.plot_out)
Пример #24
0
class SpykingCircus(mlpr.Processor):
    """Run the Spyking Circus sorter on a recording and write the firings.

    The sorter works in a temporary folder that is always removed when the
    run finishes, whether it succeeded or failed.
    """
    NAME = 'SpykingCircus'
    VERSION = '0.1.2'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(optional=True, default=1000, description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(optional=True, default=10000, description='I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort the recording and save the result to ``firings_out``.

        Reads the ``TEMPDIR`` (default ``/tmp``) and ``NUM_WORKERS``
        (default 2) environment variables.

        Raises:
            ValueError: if ``NUM_WORKERS`` is set to a non-integer value.
        """
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        # NOTE(review): the 'ironclust' prefix looks copied from the IronClust
        # processor; kept as-is so the on-disk naming does not change.
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        # Environment values are strings, so convert before passing the
        # worker count on as a number of cores.
        num_workers = int(os.environ.get('NUM_WORKERS', 2))

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for success and failure; any exception
            # raised above still propagates to the caller.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
Пример #25
0
class IronClust(mlpr.Processor):
    """Run the IronClust sorter on a recording and write the firings.

    Requires the ``IRONCLUST_SRC`` environment variable. The sorter works
    in a temporary folder that is always removed when the run finishes.
    """
    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording and save the result to ``firings_out``.

        Reads the ``IRONCLUST_SRC`` and ``TEMPDIR`` (default ``/tmp``)
        environment variables.

        Raises:
            Exception: if ``IRONCLUST_SRC`` is unset or empty.
        """
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for success and failure; any exception
            # raised above still propagates to the caller.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
Пример #26
0
class PlotAutoCorrelograms(mlpr.Processor):
    """Plot correlograms for a sorting and save the figure to a file."""
    NAME = 'spikeforest.PlotAutoCorrelograms'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(description='Recording directory',
                               directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        """Build the correlogram widget, plot it, and save via ``save_plot``."""
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=False)
        selected_channels = self.channels
        if len(selected_channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=selected_channels)
        sort = si.MdaSortingExtractor(firings_file=self.firings)
        # Only the sampling frequency of the recording is used by the widget.
        widget = sw.CrossCorrelogramsWidget(
            samplerate=rec.getSamplingFrequency(), sorting=sort)
        widget.plot()
        save_plot(self.plot_out)
        fname=save_plot(self.plot_out)
Пример #27
0
class RepeatText(mlpr.Processor):
    """Write the contents of a text file repeated ``num_repeats`` times."""
    textfile = mlpr.Input(help="input text file")
    textfile_out = mlpr.Output(help="output text file")
    # A single integer count: run() compares it with 0 and repeats the text
    # that many times, so it must not be declared as a list parameter.
    num_repeats = mlpr.IntegerParameter(
        help="Number of times to repeat the text")

    def run(self):
        """Read ``textfile`` and write its text repeated to ``textfile_out``.

        Raises:
            AssertionError: if ``num_repeats`` is negative.
        """
        assert self.num_repeats >= 0
        with open(self.textfile, 'r') as f:
            txt = f.read()
        # String repetition replaces the quadratic concatenation loop;
        # a count of 0 yields an empty output file.
        with open(self.textfile_out, 'w') as f:
            f.write(txt * self.num_repeats)
Пример #28
0
class MeanShift(mlpr.Processor):
    """Cluster the array stored in ``data`` with mean shift.

    ``bandwidth`` is a string: ``'auto'`` passes ``None`` to the estimator,
    any other value is converted with ``float``.
    """
    VERSION = '0.1.0'
    data = mlpr.Input()
    labels_out = mlpr.Output(is_array=True)
    bandwidth = mlpr.StringParameter()

    def run(self):
        """Fit scikit-learn's MeanShift and save the labels."""
        import numpy as np
        # Aliased so the estimator is not confused with this processor class.
        from sklearn.cluster import MeanShift as _Estimator

        bandwidth = None if self.bandwidth == 'auto' else float(self.bandwidth)
        model = _Estimator(bandwidth=bandwidth)
        model.fit(np.load(self.data))
        # np.save adds a '.npy' suffix to names lacking one, so write to an
        # explicit '.npy' path and then move it onto the requested output.
        staging_path = self.labels_out + '.npy'
        np.save(staging_path, model.labels_)
        os.rename(staging_path, self.labels_out)
Пример #29
0
class ComputeRecordingInfo(mlpr.Processor):
    """Write basic properties of a recording to a JSON file.

    When ``channels`` is non-empty the info is computed for the
    sub-recording restricted to those channels.
    """
    NAME = 'ComputeRecordingInfo'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(description='Recording directory',
                               directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        """Collect sample rate, channel count and duration, then save them."""
        extractor = si.MdaRecordingExtractor(
            dataset_directory=self.recording_dir,
            download=False,
        )
        selected_channels = self.channels
        if len(selected_channels) > 0:
            extractor = si.SubRecordingExtractor(
                parent_recording=extractor, channel_ids=selected_channels)
        samplerate = extractor.getSamplingFrequency()
        num_channels = len(extractor.getChannelIds())
        info = {
            'samplerate': samplerate,
            'num_channels': num_channels,
            # Duration is the frame count divided by the sample rate.
            'duration_sec': extractor.getNumFrames() / samplerate,
        }
        write_json_file(self.json_out, info)
Пример #30
0
class CombineSubsampledMandelbrot(mlpr.Processor):
    """Merge subsampled Mandelbrot arrays into one array and save it."""
    NAME = 'CombineSubsampledMandelbrot'
    VERSION = '0.1.1'

    X_list = mlpr.Input(multi=True)
    X_out = mlpr.Output()
    num_x = mlpr.IntegerParameter()

    def run(self):
        """Load every input, combine them, trim to ``num_x`` rows, and save."""
        print('=== CombineSubsampledMandelbrot ===', self.num_x)
        pieces = [np.load(path) for path in self.X_list]
        combined = combine_subsampled_mandelbrot(pieces)
        # Keep only the first num_x entries along the first axis.
        trimmed = combined[:self.num_x, :]
        np.save(self.X_out, trimmed)