Example 1
def load_ntrode(path, *, name, epoch_name, processed_path=None):
    # use the .geom.csv if it exists (we assume path ends with .mda)
    geom_file = path[0:-4] + '.geom.csv'
    if mt.findFile(geom_file):
        print('Using geometry file: {}'.format(geom_file))
    else:
        # if it doesn't exist, we will create a trivial geom later
        geom_file = None
    
    path2 = mt.realizeFile(path)
    if not path2:
        raise Exception('Unable to realize file: ' + path)
    
    X = mdaio.DiskReadMda(path2)
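    # in the raw .mda file the first dimension is channels and the second is timepoints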
    num_channels = X.N1()
    num_timepoints = X.N2()

    processed_info = load_ntrode_processed_info(processed_path, recording_path=path, epoch_name=epoch_name, ntrode_name=name)

    # here's the structure for representing ntrode information
    return dict(
        type='ntrode',
        name=name,
        epoch_name=epoch_name,
        path=path,
        processed_path=processed_path,
        recording_file=path,
        geom_file=geom_file,
        num_channels=num_channels,
        num_timepoints=num_timepoints,
        samplerate=30000,  # TODO: read the actual sample rate instead of hard-coding it
        processed_info=processed_info
    )
Example 2
    def __init__(self, dataset_directory, *, download=True, raw_fname='raw.mda', params_fname='params.json'):
        RecordingExtractor.__init__(self)
        self._dataset_directory = dataset_directory
        self._timeseries_path = dataset_directory + '/' + raw_fname
        self._dataset_params = read_dataset_params(dataset_directory, params_fname)
        self._samplerate = self._dataset_params['samplerate'] * 1.0
        if download:
            path0 = mt.realizeFile(path=self._timeseries_path)
            if not path0:
                raise Exception('Unable to realize file: ' + self._timeseries_path)
            self._timeseries_path = path0

        geom0 = dataset_directory + '/geom.csv'
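        # geom.csv is expected to hold one row of electrode coordinates per channel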
        self._geom_fname = mt.realizeFile(path=geom0)
        if not self._geom_fname:
            raise Exception('Unable to realize geom file: ' + geom0)
        self._geom = np.genfromtxt(self._geom_fname, delimiter=',')

        timeseries_path_or_url = self._timeseries_path
        if not mt.isLocalPath(timeseries_path_or_url):
            a = mt.findFile(timeseries_path_or_url)
            if not a:
                raise Exception('Cannot find timeseries file: ' + timeseries_path_or_url)
            timeseries_path_or_url = a

        X = DiskReadMda(timeseries_path_or_url)
        if self._geom.shape[0] != X.N1():
            # geometry does not match the recording; warn and fall back to a trivial geometry rather than raising
            print('WARNING: Incompatible dimensions between geom.csv and timeseries file {} <> {}'.format(self._geom.shape[0], X.N1()))
            self._geom = np.zeros((X.N1(), 2))

        self._num_channels = X.N1()
        self._num_timepoints = X.N2()
        for m in range(self._num_channels):
            self.set_channel_property(m, 'location', self._geom[m, :])
Example 3
def load_sorting_results_info(firings_path, *, recording_path, epoch_name, ntrode_name, curated=False):
    if not mt.findFile(firings_path):
        return None
    sorting = SFMdaSortingExtractor(firings_file=firings_path)
    total_num_events = 0
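    # count the total number of spike events across all units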
    for unit_id in sorting.get_unit_ids():
        spike_times = sorting.get_unit_spike_train(unit_id=unit_id)
        total_num_events = total_num_events + len(spike_times)
    return dict(
        type='sorting_results',
        epoch_name=epoch_name,
        ntrode_name=ntrode_name,
        curated=curated,
        firings_path=firings_path,
        recording_path=recording_path,
        unit_ids=sorting.get_unit_ids(),
        num_events=total_num_events
    )
Example 4
    def __init__(self, firings_file):
        SortingExtractor.__init__(self)
        if is_kbucket_url(firings_file):
            download_needed = is_url(mt.findFile(path=firings_file))
        else:
            download_needed = is_url(firings_file)
        if download_needed:
            print('Downloading file: ' + firings_file)
            self._firings_path = mt.realizeFile(path=firings_file)
            print('Done.')
        else:
            self._firings_path = mt.realizeFile(path=firings_file)
        if not self._firings_path:
            raise Exception('Unable to realize firings file: ' + firings_file)

        self._firings = readmda(self._firings_path)
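        # firings.mda rows: 1 = spike times (sample indices), 2 = unit labels (row 0 conventionally holds the primary channel)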
        self._times = self._firings[1, :]
        self._labels = self._firings[2, :]
        self._unit_ids = np.unique(self._labels).astype(int)
Example 5
def main():
    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument(
        '--output_ids',
        help=
        'Comma-separated list of IDs of the analysis outputs to include in the website.',
        required=False,
        default=None)

    args = parser.parse_args()

    use_slurm = True  # set to False to run jobs locally with multiprocessing instead of SLURM

    mt.configDownloadFrom(['spikeforest.kbucket'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'paired_monotrode',
            'synth_monotrode',
            # 'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using output ids: ', output_ids)

    print(
        '******************************** LOADING ANALYSIS OUTPUT OBJECTS...')
    for output_id in output_ids:
        slurm_working_dir = 'tmp_slurm_job_handler_' + _random_string(5)
        job_handler = mlpr.SlurmJobHandler(working_dir=slurm_working_dir)
        if use_slurm:
            job_handler.addBatchType(name='default',
                                     num_workers_per_batch=20,
                                     num_cores_per_job=1,
                                     time_limit_per_batch=1800,
                                     use_slurm=True,
                                     additional_srun_opts=['-p ccm'])
        else:
            job_handler.addBatchType(
                name='default',
                num_workers_per_batch=multiprocessing.cpu_count(),
                num_cores_per_job=1,
                use_slurm=False)
        with mlpr.JobQueue(job_handler=job_handler) as JQ:
            print('=============================================', output_id)
            print('Loading output object: {}'.format(output_id))
            output_path = (
                'key://pairio/spikeforest/spikeforest_analysis_results.{}.json'
            ).format(output_id)
            obj = mt.loadObject(path=output_path)
            # studies = obj['studies']
            # study_sets = obj.get('study_sets', [])
            # recordings = obj['recordings']
            sorting_results = obj['sorting_results']

            print(
                'Determining sorting results to process ({} total)...'.format(
                    len(sorting_results)))
            sorting_results_to_process = []
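            # skip results whose unit details are already stored under this key in the 'spikeforest' collection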
            for sr in sorting_results:
                key = dict(name='unit-details-v0.1.0',
                           recording_directory=sr['recording']['directory'],
                           firings_true=sr['firings_true'],
                           firings=sr['firings'])
                val = mt.getValue(key=key, collection='spikeforest')
                if not val:
                    sr['key'] = key
                    sorting_results_to_process.append(sr)

            print('Need to process {} of {} sorting results'.format(
                len(sorting_results_to_process), len(sorting_results)))

            recording_directories_to_process = sorted(
                list(
                    set([
                        sr['recording']['directory']
                        for sr in sorting_results_to_process
                    ])))
            print('{} recording directories to process'.format(
                len(recording_directories_to_process)))

            print('Filtering recordings...')
            filter_jobs = FilterTimeseries.createJobs([
                dict(recording_directory=recdir,
                     timeseries_out={'ext': '.mda'},
                     _timeout=600)
                for recdir in recording_directories_to_process
            ])
            filter_results = [job.execute() for job in filter_jobs]

            JQ.wait()

            filtered_timeseries_by_recdir = dict()
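            # collect the filtered timeseries outputs, failing if any filtering job returned a nonzero exit code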
            for i, recdir in enumerate(recording_directories_to_process):
                result0 = filter_results[i]
                if result0.retcode != 0:
                    raise Exception(
                        'Problem computing filtered timeseries for recording: {}'
                        .format(recdir))
                filtered_timeseries_by_recdir[recdir] = result0.outputs[
                    'timeseries_out']

            print('Creating spike sprays...')
            for sr in sorting_results_to_process:
                rec = sr['recording']
                study_name = rec['study']
                rec_name = rec['name']
                sorter_name = sr['sorter']['name']

                print('====== COMPUTING {}/{}/{}'.format(
                    study_name, rec_name, sorter_name))

                if sr.get('comparison_with_truth', None) is not None:
                    cwt = mt.loadObject(
                        path=sr['comparison_with_truth']['json'])

                    filtered_timeseries = filtered_timeseries_by_recdir[
                        rec['directory']]

                    spike_spray_job_objects = []
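                    # build one spike-spray job per matched unit in the comparison-with-truth object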
                    list0 = list(cwt.values())
                    for _, unit in enumerate(list0):
                        # print('')
                        # print('=========================== {}/{}/{} unit {} of {}'.format(study_name, rec_name, sorter_name, ii + 1, len(list0)))
                        # ssobj = create_spikesprays(rx=rx, sx_true=sx_true, sx_sorted=sx, neighborhood_size=neighborhood_size, num_spikes=num_spikes, unit_id_true=unit['unit_id'], unit_id_sorted=unit['best_unit'])

                        spike_spray_job_objects.append(
                            dict(args=dict(
                                recording_directory=rec['directory'],
                                filtered_timeseries=filtered_timeseries,
                                firings_true=os.path.join(
                                    rec['directory'], 'firings_true.mda'),
                                firings_sorted=sr['firings'],
                                unit_id_true=unit['unit_id'],
                                unit_id_sorted=unit['best_unit'],
                                json_out={'ext': '.json'},
                                _container='default',
                                _timeout=180),
                                 study_name=study_name,
                                 rec_name=rec_name,
                                 sorter_name=sorter_name,
                                 unit=unit))
                    spike_spray_jobs = CreateSpikeSprays.createJobs(
                        [obj['args'] for obj in spike_spray_job_objects])
                    spike_spray_results = [
                        job.execute() for job in spike_spray_jobs
                    ]

                    sr['spike_spray_job_objects'] = spike_spray_job_objects
                    sr['spike_spray_results'] = spike_spray_results

            JQ.wait()

        for sr in sorting_results_to_process:
            rec = sr['recording']
            study_name = rec['study']
            rec_name = rec['name']
            sorter_name = sr['sorter']['name']

            print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                                  sorter_name))

            if sr.get('comparison_with_truth', None) is not None:
                spike_spray_job_objects = sr['spike_spray_job_objects']
                spike_spray_results = sr['spike_spray_results']

                unit_details = []
                ok = True
                for i, result in enumerate(spike_spray_results):
                    obj0 = spike_spray_job_objects[i]
                    if result.retcode != 0:
                        print('WARNING: Error creating spike sprays for job:')
                        print(spike_spray_job_objects[i])
                        ok = False
                        break
                    ssobj = mt.loadObject(path=result.outputs['json_out'])
                    if ssobj is None:
                        raise Exception(
                            'Problem loading spikespray object output.')
                    address = mt.saveObject(object=ssobj,
                                            upload_to='spikeforest.kbucket')
                    unit = obj0['unit']
                    unit_details.append(
                        dict(studyName=obj0['study_name'],
                             recordingName=obj0['rec_name'],
                             sorterName=obj0['sorter_name'],
                             trueUnitId=unit['unit_id'],
                             sortedUnitId=unit['best_unit'],
                             spikeSprayUrl=mt.findFile(
                                 path=address,
                                 remote_only=True,
                                 download_from='spikeforest.kbucket')))
                if ok:
                    mt.saveObject(collection='spikeforest',
                                  key=sr['key'],
                                  object=unit_details,
                                  upload_to='spikeforest.public')
Example 6
def main():
    from mountaintools import client as mt

    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument('analysis_path',
                        help='assembled analysis file (output.json)')
    parser.add_argument(
        '--studysets',
        help='Comma-separated list of study set names to include',
        required=False,
        default=None)
    parser.add_argument('--force-run',
                        help='Force rerunning of processing',
                        action='store_true')
    parser.add_argument(
        '--force-run-all',
        help='Force rerunning of processing including filtering',
        action='store_true')
    parser.add_argument('--parallel',
                        help='Optional number of parallel jobs',
                        required=False,
                        default='0')
    parser.add_argument('--slurm',
                        help='Path to slurm config file',
                        required=False,
                        default=None)
    parser.add_argument('--cache',
                        help='The cache database to use',
                        required=False,
                        default=None)
    parser.add_argument('--job-timeout',
                        help='Timeout for processing jobs',
                        required=False,
                        type=float,
                        default=600)
    parser.add_argument('--log-file',
                        help='Log file for analysis progress',
                        required=False,
                        default=None)
    parser.add_argument(
        '--force-regenerate',
        help=
        'Whether to force regenerating spike sprays (for when code has changed)',
        action='store_true')
    parser.add_argument('--test',
                        help='Whether to just test by running only 1',
                        action='store_true')

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    with open(args.analysis_path, 'r') as f:
        analysis = json.load(f)

    if args.studysets is not None:
        studyset_names = args.studysets.split(',')
        print('Using study sets: ', studyset_names)
    else:
        studyset_names = None

    study_sets = analysis['StudySets']
    sorting_results = analysis['SortingResults']

    studies_to_include = []
    for ss in study_sets:
        if (studyset_names is None) or (ss['name'] in studyset_names):
            for study in ss['studies']:
                studies_to_include.append(study['name'])

    print('Including studies:', studies_to_include)

    print('Determining sorting results to process ({} total)...'.format(
        len(sorting_results)))
    sorting_results_to_process = []
    sorting_results_to_consider = []
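    # only consider results that have firings and a ground-truth comparison; process those without stored unit details (or all, with --force-regenerate)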
    for sr in sorting_results:
        study_name = sr['studyName']
        if study_name in studies_to_include:
            if 'firings' in sr:
                if sr.get('comparisonWithTruth', None) is not None:
                    sorting_results_to_consider.append(sr)
                    key = dict(name='unit-details-v0.1.0',
                               recording_directory=sr['recordingDirectory'],
                               firings_true=sr['firingsTrue'],
                               firings=sr['firings'])
                    val = mt.getValue(key=key, collection='spikeforest')
                    if (not val) or (args.force_regenerate):
                        sr['key'] = key
                        sorting_results_to_process.append(sr)
    if args.test and len(sorting_results_to_process) > 0:
        sorting_results_to_process = [sorting_results_to_process[0]]

    print('Need to process {} of {} sorting results'.format(
        len(sorting_results_to_process), len(sorting_results_to_consider)))

    recording_directories_to_process = sorted(
        list(
            set([
                sr['recordingDirectory'] for sr in sorting_results_to_process
            ])))
    print('{} recording directories to process'.format(
        len(recording_directories_to_process)))

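    # choose a hither job handler: parallel on this machine, SLURM, or the default (None)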
    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                             **slurm_config['cpu'])
    else:
        job_handler = None

    print('Filtering recordings...')
    filter_results = []
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for recdir in recording_directories_to_process:
            result = filter_recording.run(recording_directory=recdir,
                                          timeseries_out=hither.File())
            filter_results.append(result)
    filtered_timeseries_by_recdir = dict()
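    # collect the filtered timeseries paths, failing if any filtering job did not succeed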
    for i, recdir in enumerate(recording_directories_to_process):
        result0 = filter_results[i]
        if not result0.success:
            raise Exception(
                'Problem computing filtered timeseries for recording: {}'.
                format(recdir))
        filtered_timeseries_by_recdir[
            recdir] = result0.outputs.timeseries_out._path

    print('Creating spike sprays...')
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run or args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for sr in sorting_results_to_process:
            recdir = sr['recordingDirectory']
            study_name = sr['studyName']
            rec_name = sr['recordingName']
            sorter_name = sr['sorterName']

            print('====== COMPUTING {}/{}/{}'.format(study_name, rec_name,
                                                     sorter_name))

            cwt = ka.load_object(path=sr['comparisonWithTruth']['json'])

            filtered_timeseries = filtered_timeseries_by_recdir[recdir]

            spike_spray_results = []
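            # run one spike-spray job per matched unit in the comparison-with-truth object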
            list0 = list(cwt.values())
            for _, unit in enumerate(list0):
                result = create_spike_sprays.run(
                    recording_directory=recdir,
                    filtered_timeseries=filtered_timeseries,
                    firings_true=os.path.join(recdir, 'firings_true.mda'),
                    firings_sorted=sr['firings'],
                    unit_id_true=unit['unit_id'],
                    unit_id_sorted=unit['best_unit'],
                    json_out=hither.File())
                setattr(result, 'unit', unit)
                spike_spray_results.append(result)
            sr['spike_spray_results'] = spike_spray_results

    for sr in sorting_results_to_process:
        recdir = sr['recordingDirectory']
        study_name = sr['studyName']
        rec_name = sr['recordingName']
        sorter_name = sr['sorterName']

        print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                              sorter_name))
        spike_spray_results = sr['spike_spray_results']
        key = sr['key']

        unit_details = []
        ok = True
        for i, result in enumerate(spike_spray_results):
            if not result.success:
                print(
                    'WARNING: Error creating spike sprays for {}/{}/{}'.format(
                        study_name, rec_name, sorter_name))
                ok = False
                break
            ssobj = ka.load_object(result.outputs.json_out._path)
            if ssobj is None:
                raise Exception('Problem loading spikespray object output.')
            address = mt.saveObject(object=ssobj,
                                    upload_to='spikeforest.kbucket')
            unit = getattr(result, 'unit')
            unit_details.append(
                dict(studyName=study_name,
                     recordingName=rec_name,
                     sorterName=sorter_name,
                     trueUnitId=unit['unit_id'],
                     sortedUnitId=unit['best_unit'],
                     spikeSprayUrl=mt.findFile(
                         path=address,
                         remote_only=True,
                         download_from='spikeforest.kbucket')))

        if ok:
            mt.saveObject(collection='spikeforest',
                          key=key,
                          object=unit_details,
                          upload_to='spikeforest.public')
Example 7
    def firingsTrueFileIsLocal(self):
        fname = self.directory() + '/firings_true.mda'
        # local_only avoids triggering a download; a URL result still counts as non-local
        fname2 = mt.findFile(fname, local_only=True)
        if fname2 and (not _is_url(fname2)):
            return True
        return False
Example 8
    def recordingFileIsLocal(self):
        fname = self.directory() + '/raw.mda'
        fname2 = mt.findFile(fname, local_only=True)
        if fname2 and (not _is_url(fname2)):
            return True
        return False