def main():

    parser = argparse.ArgumentParser(
        description='update spike-front analysis history')
    parser.add_argument(
        '--delete-and-init',
        help='Use this only the first time. Will wipe out the history.',
        action='store_true')
    parser.add_argument('--show',
                        help='Show the analysis history.',
                        action='store_true')

    args = parser.parse_args()

    history_path = 'key://pairio/spikeforest/spike-front-analysis-history.json'
    results_path = 'key://pairio/spikeforest/spike-front-results.json'

    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])
    if args.delete_and_init:
        print('Initializing analysis history...')
        mt.createSnapshot(mt.saveObject(object=dict(analyses=[])),
                          dest_path=history_path,
                          upload_to='spikeforest.public')
        print('Done.')
        return

    print('Loading history...')
    history = mt.loadObject(path=history_path)
    assert history, 'Unable to load history'

    if args.show:
        for aa in history['analyses']:
            print(
                '==============================================================='
            )
            print('ANALYSIS: {}'.format(aa['General']['dateUpdated']))
            print('PATH: {}'.format(aa['path']))
            print(json.dumps(aa['General'], indent=4))
            print('')
        return

    print('Loading current results...')
    spike_front_results = mt.loadObject(path=results_path)
    assert spike_front_results, 'Unable to load current results'

    sha1_url = mt.saveObject(object=spike_front_results,
                             basename='analysis.json')
    for aa in history['analyses']:
        if aa['path'] == sha1_url:
            print('Analysis already stored in history.')
            return

    history['analyses'].append(
        dict(path=sha1_url, General=spike_front_results['General'][0]))

    print('Saving updated history to {}'.format(history_path))
    mt.createSnapshot(mt.saveObject(object=history),
                      dest_path=history_path,
                      upload_to='spikeforest.public')
    print('Done.')
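# A minimal sketch of the history object this script maintains; the field names come from
# the code above, while the concrete values are placeholders.
example_history = {
    'analyses': [
        {
            'path': 'sha1://<hash>/analysis.json',         # address returned by mt.saveObject
            'General': {'dateUpdated': '<ISO timestamp>'}  # first entry of the results' General list
        }
    ]
}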
Example #2
def main():
    job = RepeatText.createJob(textfile=mlpr.PLACEHOLDER,
                               textfile_out=dict(ext='.txt'),
                               num_repeats=mlpr.PLACEHOLDER)
    mt.saveObject(object=job.getObject(),
                  dest_path='repeat_text.json',
                  indent=4)

    job = sa.ComputeRecordingInfo.createJob(recording_dir=mlpr.PLACEHOLDER,
                                            channels=[],
                                            json_out={'ext': '.json'},
                                            _container='default')
    mt.saveObject(object=job.getObject(),
                  dest_path='ComputeRecordingInfo.json',
                  indent=4)
Example #3
def assembleBatchResults(*, batch_name):
    batch = mt.loadObject(key=dict(batch_name=batch_name))
    jobs = batch['jobs']

    print('Assembling results for batch {} with {} jobs'.format(
        batch_name, len(jobs)))
    job_results = []
    for job in jobs:
        print('ASSEMBLING: ' + job['label'])
        result = mt.loadObject(key=job)
        if not result:
            raise Exception('Unable to load object for job: ' + job['label'])
        job_results.append(dict(job=job, result=result))
    print('Saving results...')
    mt.saveObject(key=dict(name='job_results', batch_name=batch_name),
                  object=dict(job_results=job_results))
    print('Done.')
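# A minimal usage sketch for the function above; 'example_batch' is a placeholder batch name,
# and the batch object is assumed to have already been stored under key dict(batch_name=...).
assembleBatchResults(batch_name='example_batch')
job_results = mt.loadObject(key=dict(name='job_results', batch_name='example_batch'))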
Example #4
def install_waveclus(repo, commit):
    spikeforest_alg_install_path = get_install_path()
    key = dict(
        alg='waveclus',
        repo=repo,
        commit=commit
    )
    source_path = spikeforest_alg_install_path + '/waveclus_' + commit
    if os.path.exists(source_path):
        # The dir hash method does not seem to be working for some reason here
        # hash0 = mt.computeDirHash(source_path)
        # if hash0 == mt.getValue(key=key):
        #     print('waveclus is already auto-installed.')
        #     return source_path

        a = mt.loadObject(path=source_path + '/spikeforest.json')
        if a:
            if mt.sha1OfObject(a) == mt.sha1OfObject(key):
                print('waveclus is already auto-installed.')
                return source_path

        print('Removing directory: {}'.format(source_path))
        shutil.rmtree(source_path)

    script = """
    #!/bin/bash
    set -e

    git clone {repo} {source_path}
    cd {source_path}
    git checkout {commit}
    """.format(repo=repo, commit=commit, source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Install script returned a non-zero exit code.')

    # The dir hash method does not seem to be working for some reason here
    # hash0 = mt.computeDirHash(source_path)
    # mt.setValue(key=key, value=hash0)
    mt.saveObject(object=key, dest_path=source_path + '/spikeforest.json')

    return source_path
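# A minimal usage sketch for the installer above; the repository URL and commit hash are
# placeholders, not values taken from the SpikeForest configuration.
waveclus_path = install_waveclus(
    repo='https://github.com/example/wave_clus',           # placeholder repository URL
    commit='0123456789abcdef0123456789abcdef01234567')     # placeholder commit hash
print('waveclus source installed at:', waveclus_path)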
Example #5
def _run_job(job):
    val = mt.getValue(key=job)
    if val:
        return
    code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
    if not mt.setValue(key=job, value='in-process-' + code, overwrite=False):
        return
    status = dict(time_started=_make_timestamp(), status='running')
    _set_job_status(job, status)

    print('Running job: ' + job['label'])
    try:
        result = _do_run_job(job)
    except:
        status['time_finished'] = _make_timestamp()
        status['status'] = 'error'
        status['error'] = 'Exception in _do_run_job'
        val = mt.getValue(key=job)
        if val == 'in-process-' + code:
            _set_job_status(job, status)
        raise

    val = mt.getValue(key=job)
    if val != 'in-process-' + code:
        print(
            'Not saving result because in-process code does not match {} <> {}.'
            .format(val, 'in-process-' + code))
        return

    status['time_finished'] = _make_timestamp()
    status['result'] = result
    if 'error' in result:
        print('Error running job: ' + result['error'])
        status['status'] = 'error'
        status['error'] = result['error']
        _set_job_status(job, status)
        mt.setValue(key=job, value='error-' + code)
        return
    status['status'] = 'finished'
    mt.saveObject(
        key=job, object=result
    )  # Not needed in future, because we should instead use the status object
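# A minimal sketch of the optimistic job-claim idiom used by _run_job above, assuming the
# same truthy-on-success semantics of mt.setValue shown in that code; the job key is a
# placeholder.
import random
import string

from mountaintools import client as mt

job_key = dict(name='example-job')
claim = 'in-process-' + ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
if mt.setValue(key=job_key, value=claim, overwrite=False):
    # ... do the work ...
    if mt.getValue(key=job_key) == claim:
        pass  # the claim is still ours, so it is safe to record the result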
Example #6
def main():
    path = mt.createSnapshot(path='recordings_out')
    mt.configDownloadFrom('spikeforest.public')
    X = mt.readDir(path)
    for study_set_name, d in X['dirs'].items():
        study_sets = []
        studies = []
        recordings = []
        study_sets.append(dict(
            name=study_set_name + '_b',
            info=dict(),
            description=''
        ))
        for study_name, d2 in d['dirs'].items():
            study_dir = path + '/' + study_set_name + '/' + study_name
            study0 = dict(
                name=study_name,
                study_set=study_set_name + '_b',
                directory=study_dir,
                description=''
            )
            studies.append(study0)
            index_within_study = 0
            for recording_name, d3 in d2['dirs'].items():
                recdir = study_dir + '/' + recording_name
                recordings.append(dict(
                    name=recording_name,
                    study=study_name,
                    directory=recdir,
                    firings_true=recdir + '/firings_true.mda',
                    index_within_study=index_within_study,
                    description='One of the recordings in the {} study'.format(study_name)
                ))
                index_within_study = index_within_study + 1

        print('Saving object...')
        group_name = study_set_name + '_b'
        address = mt.saveObject(
            object=dict(
                studies=studies,
                recordings=recordings,
                study_sets=study_sets
            ),
            key=dict(name='spikeforest_recording_group', group_name=group_name),
            upload_to='spikeforest.public'
        )
        if not address:
            raise Exception('Problem uploading object to kachery')

        output_fname = 'key://pairio/spikeforest/spikeforest_recording_group.{}.json'.format(group_name)
        print('Saving output to {}'.format(output_fname))
        mt.createSnapshot(path=address, dest_path=output_fname)
Example #7
def do_prepare(recording_group, study_name):
    print(recording_group, study_name)
    X = mt.loadObject(
        path="key://pairio/spikeforest/spikeforest_recording_group.{}.json".
        format(recording_group))
    studies = [y for y in X['studies'] if (y['name'] == study_name)]
    recordings = [y for y in X['recordings'] if y['study'] == study_name]
    recordings = recordings[0:1]
    study_sets = X['study_sets']

    Y = dict(studies=studies, recordings=recordings, study_sets=study_sets)
    address = mt.saveObject(object=Y)
    assert address is not None
    dest_path = 'key://pairio/spikeforest/spikeforest_recording_group.test_{}.json'.format(
        recording_group)
    print(dest_path)
    mt.createSnapshot(path=address,
                      upload_to='spikeforest.kbucket',
                      dest_path=dest_path)
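# A minimal usage sketch for do_prepare above; both arguments are placeholders for names that
# appear in the corresponding spikeforest_recording_group object.
do_prepare(recording_group='example_group', study_name='example_study')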
Example #8
def main():
    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument(
        '--output_ids',
        help=
        'Comma-separated list of IDs of the analysis outputs to include in the website.',
        required=False,
        default=None)

    args = parser.parse_args()

    use_slurm = True

    mt.configDownloadFrom(['spikeforest.kbucket'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'paired_monotrode',
            'synth_monotrode',
            # 'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using output ids: ', output_ids)

    print(
        '******************************** LOADING ANALYSIS OUTPUT OBJECTS...')
    for output_id in output_ids:
        slurm_working_dir = 'tmp_slurm_job_handler_' + _random_string(5)
        job_handler = mlpr.SlurmJobHandler(working_dir=slurm_working_dir)
        if use_slurm:
            job_handler.addBatchType(name='default',
                                     num_workers_per_batch=20,
                                     num_cores_per_job=1,
                                     time_limit_per_batch=1800,
                                     use_slurm=True,
                                     additional_srun_opts=['-p ccm'])
        else:
            job_handler.addBatchType(
                name='default',
                num_workers_per_batch=multiprocessing.cpu_count(),
                num_cores_per_job=1,
                use_slurm=False)
        with mlpr.JobQueue(job_handler=job_handler) as JQ:
            print('=============================================', output_id)
            print('Loading output object: {}'.format(output_id))
            output_path = (
                'key://pairio/spikeforest/spikeforest_analysis_results.{}.json'
            ).format(output_id)
            obj = mt.loadObject(path=output_path)
            # studies = obj['studies']
            # study_sets = obj.get('study_sets', [])
            # recordings = obj['recordings']
            sorting_results = obj['sorting_results']

            print(
                'Determining sorting results to process ({} total)...'.format(
                    len(sorting_results)))
            sorting_results_to_process = []
            for sr in sorting_results:
                key = dict(name='unit-details-v0.1.0',
                           recording_directory=sr['recording']['directory'],
                           firings_true=sr['firings_true'],
                           firings=sr['firings'])
                val = mt.getValue(key=key, collection='spikeforest')
                if not val:
                    sr['key'] = key
                    sorting_results_to_process.append(sr)

            print('Need to process {} of {} sorting results'.format(
                len(sorting_results_to_process), len(sorting_results)))

            recording_directories_to_process = sorted(
                list(
                    set([
                        sr['recording']['directory']
                        for sr in sorting_results_to_process
                    ])))
            print('{} recording directories to process'.format(
                len(recording_directories_to_process)))

            print('Filtering recordings...')
            filter_jobs = FilterTimeseries.createJobs([
                dict(recording_directory=recdir,
                     timeseries_out={'ext': '.mda'},
                     _timeout=600)
                for recdir in recording_directories_to_process
            ])
            filter_results = [job.execute() for job in filter_jobs]

            JQ.wait()

            filtered_timeseries_by_recdir = dict()
            for i, recdir in enumerate(recording_directories_to_process):
                result0 = filter_results[i]
                if result0.retcode != 0:
                    raise Exception(
                        'Problem computing filtered timeseries for recording: {}'
                        .format(recdir))
                filtered_timeseries_by_recdir[recdir] = result0.outputs[
                    'timeseries_out']

            print('Creating spike sprays...')
            for sr in sorting_results_to_process:
                rec = sr['recording']
                study_name = rec['study']
                rec_name = rec['name']
                sorter_name = sr['sorter']['name']

                print('====== COMPUTING {}/{}/{}'.format(
                    study_name, rec_name, sorter_name))

                if sr.get('comparison_with_truth', None) is not None:
                    cwt = mt.loadObject(
                        path=sr['comparison_with_truth']['json'])

                    filtered_timeseries = filtered_timeseries_by_recdir[
                        rec['directory']]

                    spike_spray_job_objects = []
                    list0 = list(cwt.values())
                    for _, unit in enumerate(list0):
                        # print('')
                        # print('=========================== {}/{}/{} unit {} of {}'.format(study_name, rec_name, sorter_name, ii + 1, len(list0)))
                        # ssobj = create_spikesprays(rx=rx, sx_true=sx_true, sx_sorted=sx, neighborhood_size=neighborhood_size, num_spikes=num_spikes, unit_id_true=unit['unit_id'], unit_id_sorted=unit['best_unit'])

                        spike_spray_job_objects.append(
                            dict(args=dict(
                                recording_directory=rec['directory'],
                                filtered_timeseries=filtered_timeseries,
                                firings_true=os.path.join(
                                    rec['directory'], 'firings_true.mda'),
                                firings_sorted=sr['firings'],
                                unit_id_true=unit['unit_id'],
                                unit_id_sorted=unit['best_unit'],
                                json_out={'ext': '.json'},
                                _container='default',
                                _timeout=180),
                                 study_name=study_name,
                                 rec_name=rec_name,
                                 sorter_name=sorter_name,
                                 unit=unit))
                    spike_spray_jobs = CreateSpikeSprays.createJobs(
                        [obj['args'] for obj in spike_spray_job_objects])
                    spike_spray_results = [
                        job.execute() for job in spike_spray_jobs
                    ]

                    sr['spike_spray_job_objects'] = spike_spray_job_objects
                    sr['spike_spray_results'] = spike_spray_results

            JQ.wait()

        for sr in sorting_results_to_process:
            rec = sr['recording']
            study_name = rec['study']
            rec_name = rec['name']
            sorter_name = sr['sorter']['name']

            print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                                  sorter_name))

            if sr.get('comparison_with_truth', None) is not None:
                spike_spray_job_objects = sr['spike_spray_job_objects']
                spike_spray_results = sr['spike_spray_results']

                unit_details = []
                ok = True
                for i, result in enumerate(spike_spray_results):
                    obj0 = spike_spray_job_objects[i]
                    if result.retcode != 0:
                        print('WARNING: Error creating spike sprays for job:')
                        print(spike_spray_job_objects[i])
                        ok = False
                        break
                    ssobj = mt.loadObject(path=result.outputs['json_out'])
                    if ssobj is None:
                        raise Exception(
                            'Problem loading spikespray object output.')
                    address = mt.saveObject(object=ssobj,
                                            upload_to='spikeforest.kbucket')
                    unit = obj0['unit']
                    unit_details.append(
                        dict(studyName=obj0['study_name'],
                             recordingName=obj0['rec_name'],
                             sorterName=obj0['sorter_name'],
                             trueUnitId=unit['unit_id'],
                             sortedUnitId=unit['best_unit'],
                             spikeSprayUrl=mt.findFile(
                                 path=address,
                                 remote_only=True,
                                 download_from='spikeforest.kbucket'),
                             _container='default'))
                if ok:
                    mt.saveObject(collection='spikeforest',
                                  key=sr['key'],
                                  object=unit_details,
                                  upload_to='spikeforest.public')
Example #9
def main():
    from mountaintools import client as mt

    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument('analysis_path',
                        help='assembled analysis file (output.json)')
    parser.add_argument(
        '--studysets',
        help='Comma-separated list of study set names to include',
        required=False,
        default=None)
    parser.add_argument('--force-run',
                        help='Force rerunning of processing',
                        action='store_true')
    parser.add_argument(
        '--force-run-all',
        help='Force rerunning of processing including filtering',
        action='store_true')
    parser.add_argument('--parallel',
                        help='Optional number of parallel jobs',
                        required=False,
                        default='0')
    parser.add_argument('--slurm',
                        help='Path to slurm config file',
                        required=False,
                        default=None)
    parser.add_argument('--cache',
                        help='The cache database to use',
                        required=False,
                        default=None)
    parser.add_argument('--job-timeout',
                        help='Timeout (in seconds) for processing jobs',
                        required=False,
                        default=600)
    parser.add_argument('--log-file',
                        help='Log file for analysis progress',
                        required=False,
                        default=None)
    parser.add_argument(
        '--force-regenerate',
        help=
        'Whether to force regenerating spike sprays (for when code has changed)',
        action='store_true')
    parser.add_argument('--test',
                        help='Whether to just test by running only one sorting result',
                        action='store_true')

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    with open(args.analysis_path, 'r') as f:
        analysis = json.load(f)

    if args.studysets is not None:
        studyset_names = args.studysets.split(',')
        print('Using study sets: ', studyset_names)
    else:
        studyset_names = None

    study_sets = analysis['StudySets']
    sorting_results = analysis['SortingResults']

    studies_to_include = []
    for ss in study_sets:
        if (studyset_names is None) or (ss['name'] in studyset_names):
            for study in ss['studies']:
                studies_to_include.append(study['name'])

    print('Including studies:', studies_to_include)

    print('Determining sorting results to process ({} total)...'.format(
        len(sorting_results)))
    sorting_results_to_process = []
    sorting_results_to_consider = []
    for sr in sorting_results:
        study_name = sr['studyName']
        if study_name in studies_to_include:
            if 'firings' in sr:
                if sr.get('comparisonWithTruth', None) is not None:
                    sorting_results_to_consider.append(sr)
                    key = dict(name='unit-details-v0.1.0',
                               recording_directory=sr['recordingDirectory'],
                               firings_true=sr['firingsTrue'],
                               firings=sr['firings'])
                    val = mt.getValue(key=key, collection='spikeforest')
                    if (not val) or (args.force_regenerate):
                        sr['key'] = key
                        sorting_results_to_process.append(sr)
    if args.test and len(sorting_results_to_process) > 0:
        sorting_results_to_process = [sorting_results_to_process[0]]

    print('Need to process {} of {} sorting results'.format(
        len(sorting_results_to_process), len(sorting_results_to_consider)))

    recording_directories_to_process = sorted(
        list(
            set([
                sr['recordingDirectory'] for sr in sorting_results_to_process
            ])))
    print('{} recording directories to process'.format(
        len(recording_directories_to_process)))

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                             **slurm_config['cpu'])
    else:
        job_handler = None

    print('Filtering recordings...')
    filter_results = []
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for recdir in recording_directories_to_process:
            result = filter_recording.run(recording_directory=recdir,
                                          timeseries_out=hither.File())
            filter_results.append(result)
    filtered_timeseries_by_recdir = dict()
    for i, recdir in enumerate(recording_directories_to_process):
        result0 = filter_results[i]
        if not result0.success:
            raise Exception(
                'Problem computing filtered timeseries for recording: {}'.
                format(recdir))
        filtered_timeseries_by_recdir[
            recdir] = result0.outputs.timeseries_out._path

    print('Creating spike sprays...')
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run or args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for sr in sorting_results_to_process:
            recdir = sr['recordingDirectory']
            study_name = sr['studyName']
            rec_name = sr['recordingName']
            sorter_name = sr['sorterName']

            print('====== COMPUTING {}/{}/{}'.format(study_name, rec_name,
                                                     sorter_name))

            cwt = ka.load_object(path=sr['comparisonWithTruth']['json'])

            filtered_timeseries = filtered_timeseries_by_recdir[recdir]

            spike_spray_results = []
            list0 = list(cwt.values())
            for _, unit in enumerate(list0):
                result = create_spike_sprays.run(
                    recording_directory=recdir,
                    filtered_timeseries=filtered_timeseries,
                    firings_true=os.path.join(recdir, 'firings_true.mda'),
                    firings_sorted=sr['firings'],
                    unit_id_true=unit['unit_id'],
                    unit_id_sorted=unit['best_unit'],
                    json_out=hither.File())
                setattr(result, 'unit', unit)
                spike_spray_results.append(result)
            sr['spike_spray_results'] = spike_spray_results

    for sr in sorting_results_to_process:
        recdir = sr['recordingDirectory']
        study_name = sr['studyName']
        rec_name = sr['recordingName']
        sorter_name = sr['sorterName']

        print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                              sorter_name))
        spike_spray_results = sr['spike_spray_results']
        key = sr['key']

        unit_details = []
        ok = True
        for i, result in enumerate(spike_spray_results):
            if not result.success:
                print(
                    'WARNING: Error creating spike sprays for {}/{}/{}'.format(
                        study_name, rec_name, sorter_name))
                ok = False
                break
            ssobj = ka.load_object(result.outputs.json_out._path)
            if ssobj is None:
                raise Exception('Problem loading spikespray object output.')
            address = mt.saveObject(object=ssobj,
                                    upload_to='spikeforest.kbucket')
            unit = getattr(result, 'unit')
            unit_details.append(
                dict(studyName=study_name,
                     recordingName=rec_name,
                     sorterName=sorter_name,
                     trueUnitId=unit['unit_id'],
                     sortedUnitId=unit['best_unit'],
                     spikeSprayUrl=mt.findFile(
                         path=address,
                         remote_only=True,
                         download_from='spikeforest.kbucket')))

        if ok:
            mt.saveObject(collection='spikeforest',
                          key=key,
                          object=unit_details,
                          upload_to='spikeforest.public')
Example #10
    return studies, recordings, study_sets


# Prepare the studies
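# (Note: this snippet is truncated; `basedir`, `upload_to`, and `group_name` are defined
# earlier in the original script and are not shown here.)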
studies, recordings, study_sets = prepare_manual_buzsaki_studies(
    basedir=basedir)

print('Uploading files to kachery...')
for rec in recordings:
    mt.createSnapshot(rec['directory'],
                      upload_to=upload_to,
                      upload_recursive=True)

print('Saving object...')
for ut in [upload_to]:
    address = mt.saveObject(object=dict(studies=studies,
                                        recordings=recordings,
                                        study_sets=study_sets),
                            key=dict(name='spikeforest_recording_group',
                                     group_name=group_name),
                            upload_to=ut)
    if not address:
        raise Exception('Problem uploading object to {}'.format(ut))

output_fname = 'key://pairio/spikeforest/spikeforest_recording_group.{}.json'.format(
    group_name)
print('Saving output to {}'.format(output_fname))
mt.createSnapshot(path=address, dest_path=output_fname)

print('Done.')
Example #11
def install_kilosort2(repo, commit):
    spikeforest_alg_install_path = get_install_path()
    key = dict(alg='kilosort2', repo=repo, commit=commit)
    source_path = spikeforest_alg_install_path + '/kilosort2_' + commit
    if os.path.exists(source_path):
        # The dir hash method does not seem to be working for some reason here
        # hash0 = mt.computeDirHash(source_path)
        # if hash0 == mt.getValue(key=key):
        #     print('Kilosort2 is already auto-installed.')
        #     return source_path

        a = mt.loadObject(path=source_path + '/spikeforest.json')
        if a:
            if mt.sha1OfObject(a) == mt.sha1OfObject(key):
                print('Kilosort2 is already auto-installed.')
                return source_path

        print('Removing directory: {}'.format(source_path))
        shutil.rmtree(source_path)

    script = """
    #!/bin/bash
    set -e

    git clone {repo} {source_path}
    cd {source_path}
    git checkout {commit}
    """.format(repo=repo, commit=commit, source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Install script returned a non-zero exit code.')

    # make sure module unload gcc/7.4.0
    compile_gpu = mlpr.ShellScript(script="""
    function compile_gpu

    try
        [~,path_nvcc_] = system('which nvcc');
        path_nvcc_ = strrep(path_nvcc_, 'nvcc', '');
        disp(['path_nvcc_: ', path_nvcc_]);
        setenv('MW_NVCC_PATH', path_nvcc_);
        run('mexGPUall.m');
    catch
        disp('Problem running mexGPUall.');
        disp(lasterr());
        exit(-1)
    end;
    exit(0)
    """)
    compile_gpu.write(script_path=source_path + '/CUDA/compile_gpu.m')

    script = """
    #!/bin/bash
    set -e

    cd {source_path}/CUDA
    matlab -nodisplay -nosplash -r "compile_gpu"
    """.format(source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Compile GPU script returned a non-zero exit code.')

    # The dir hash method does not seem to be working for some reason here
    # hash0 = mt.computeDirHash(source_path)
    # mt.setValue(key=key, value=hash0)
    mt.saveObject(object=key, dest_path=source_path + '/spikeforest.json')

    return source_path
Example #12
def main():
    mt.configDownloadFrom('spikeforest.public')
    templates_path = 'sha1dir://95dba567b5168bacb480411ca334ffceb96b8c19.2019-06-11.templates'
    recordings_path = 'recordings_out'

    tempgen_tetrode = templates_path + '/templates_tetrode.h5'
    tempgen_neuronexus = templates_path + '/templates_neuronexus.h5'
    tempgen_neuropixels = templates_path + '/templates_neuropixels.h5'
    tempgen_neuronexus_drift = templates_path + '/templates_neuronexus_drift.h5'

    noise_level = [10, 20]
    duration = 600
    bursting = [False, True]
    nrec = 2  # change this to 10
    ei_ratio = 0.8
    rec_dict = {
        'tetrode': {
            'ncells': [10, 20],
            'tempgen': tempgen_tetrode,
            'drifting': False
        },
        'neuronexus': {
            'ncells': [10, 20, 40],
            'tempgen': tempgen_neuronexus,
            'drifting': False
        },
        'neuropixels': {
            'ncells': [20, 40, 60],
            'tempgen': tempgen_neuropixels,
            'drifting': False
        },
        'neuronexus_drift': {
            'ncells': [10, 20, 40],
            'tempgen': tempgen_neuronexus_drift,
            'drifting': True
        }
    }

    # optional: if drifting change drift velocity
    # recording_params['recordings']['drift_velocity] = ...

    # Generate and save recordings
    if os.path.exists(recordings_path):
        shutil.rmtree(recordings_path)
    os.mkdir(recordings_path)

    # Set up slurm configuration
    slurm_working_dir = 'tmp_slurm_job_handler_' + _random_string(5)
    job_handler = mlpr.SlurmJobHandler(working_dir=slurm_working_dir)
    use_slurm = True
    job_timeout = 3600 * 4
    if use_slurm:
        job_handler.addBatchType(name='default',
                                 num_workers_per_batch=4,
                                 num_cores_per_job=6,
                                 time_limit_per_batch=job_timeout * 3,
                                 use_slurm=True,
                                 max_simultaneous_batches=20,
                                 additional_srun_opts=['-p ccm'])
    else:
        job_handler.addBatchType(
            name='default',
            num_workers_per_batch=multiprocessing.cpu_count(),
            num_cores_per_job=2,
            max_simultaneous_batches=1,
            use_slurm=False)
    with mlpr.JobQueue(job_handler=job_handler) as JQ:
        results_to_write = []
        for rec_type in rec_dict.keys():
            study_set_name = 'SYNTH_MEAREC_{}'.format(rec_type.upper())
            os.mkdir(recordings_path + '/' + study_set_name)
            params = dict()
            params['duration'] = duration
            params['drifting'] = rec_dict[rec_type]['drifting']
            # reduce minimum distance for dense recordings
            params['min_dist'] = 15
            for ncells in rec_dict[rec_type]['ncells']:
                # changing number of cells
                n_exc = int(ei_ratio *
                            10)  # intentionally replaced nrec by 10 here
                params['n_exc'] = n_exc
                params['n_inh'] = ncells - n_exc
                for n in noise_level:
                    # changing noise level
                    params['noise_level'] = n
                    for b in bursting:
                        bursting_str = ''
                        if b:
                            bursting_str = '_bursting'
                        study_name = 'synth_mearec_{}_noise{}_K{}{}'.format(
                            rec_type, n, ncells, bursting_str)
                        os.mkdir(recordings_path + '/' + study_set_name + '/' +
                                 study_name)
                        for i in range(nrec):
                            # set random seeds
                            params[
                                'seed'] = i  # intentionally doing it this way

                            # changing bursting and shape modulation
                            print('Generating', rec_type, 'recording with',
                                  ncells, 'noise level', n, 'bursting', b)
                            params['bursting'] = b
                            params['shape_mod'] = b
                            templates0 = mt.realizeFile(
                                path=rec_dict[rec_type]['tempgen'])
                            result0 = GenerateMearecRecording.execute(
                                **params,
                                templates_in=templates0,
                                recording_out=dict(ext='.h5'))
                            mda_output_folder = recordings_path + '/' + study_set_name + '/' + study_name + '/' + '{}'.format(
                                i)
                            results_to_write.append(
                                dict(result=result0,
                                     mda_output_folder=mda_output_folder))
        JQ.wait()

        for x in results_to_write:
            result0: mlpr.MountainJobResult = x['result']
            mda_output_folder = x['mda_output_folder']
            path = mt.realizeFile(path=result0.outputs['recording_out'])
            recording = se.MEArecRecordingExtractor(recording_path=path)
            sorting_true = se.MEArecSortingExtractor(recording_path=path)
            se.MdaRecordingExtractor.write_recording(
                recording=recording, save_path=mda_output_folder)
            se.MdaSortingExtractor.write_sorting(sorting=sorting_true,
                                                 save_path=mda_output_folder +
                                                 '/firings_true.mda')
            if result0.console_out:
                mt.realizeFile(path=result0.console_out,
                               dest_path=mda_output_folder +
                               '.console_out.txt')
            if result0.runtime_info:
                mt.saveObject(object=result0.runtime_info,
                              dest_path=mda_output_folder +
                              '.runtime_info.json')

    print('Creating and uploading snapshot...')
    sha1dir_path = mt.createSnapshot(path=recordings_path,
                                     upload_to='spikeforest.public',
                                     upload_recursive=False)
    # sha1dir_path = mt.createSnapshot(path=recordings_path, upload_to='spikeforest.kbucket', upload_recursive=True)
    print(sha1dir_path)
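# An illustration of one output path produced by the nested loops above, following the naming
# pattern in the code (study set, study, then recording index):
#
#   recordings_out/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10/0/
#       firings_true.mda                 (ground-truth sorting)
#       ... MDA recording files written by MdaRecordingExtractor ...
#   recordings_out/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10/0.console_out.txt
#   recordings_out/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10/0.runtime_info.json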
Example #13
def main():
    parser = argparse.ArgumentParser(description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--output_ids', help='Comma-separated list of IDs of the analysis outputs to include in the website.', required=False, default=None)
    parser.add_argument('--upload_to', help='Optional kachery to upload to', required=False, default=None)
    parser.add_argument('--dest_key_path', help='Optional destination key path', required=False, default=None)

    args = parser.parse_args()

    if args.upload_to:
        upload_to = args.upload_to.split(',')
    else:
        upload_to = None

    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'paired_monotrode',
            'synth_monotrode',
            'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            # 'SYNTH_MEAREC_TETRODE_b',
            # 'SYNTH_MEAREC_NEURONEXUS_b',
            # 'SYNTH_MEAREC_NEUROPIXELS_b',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using output ids: ', output_ids)

    sorters_to_include = set([
        'HerdingSpikes2',
        'IronClust',
        'IronClust1',
        'IronClust2',
        'IronClust3',
        'IronClust4',
        'IronClust5',
        'IronClust6',
        'IronClust7',
        'IronClust8',
        'JRClust',
        'KiloSort',
        'KiloSort2',
        'Klusta',
        'MountainSort4',
        'SpykingCircus',
        'Tridesclous',
        'Waveclus',
        # 'Yass'
    ])

    print('******************************** LOADING ANALYSIS OUTPUT OBJECTS...')
    studies = []
    study_sets = []
    recordings = []
    sorting_results = []
    for output_id in output_ids:
        print('Loading output object: {}'.format(output_id))
        output_path = ('key://pairio/spikeforest/spikeforest_analysis_results.{}.json').format(output_id)
        obj = mt.loadObject(path=output_path)
        if obj is not None:
            studies = studies + obj['studies']
            print(obj.keys())
            study_sets = study_sets + obj.get('study_sets', [])
            recordings = recordings + obj['recordings']
            sorting_results = sorting_results + obj['sorting_results']
        else:
            raise Exception('Unable to load: {}'.format(output_path))

    # ALGORITHMS
    print('******************************** ASSEMBLING ALGORITHMS...')
    algorithms_by_processor_name = dict()
    Algorithms = []
    basepath = '../../spikeforest/spikeforestsorters/descriptions'
    repo_base_url = 'https://github.com/flatironinstitute/spikeforest/blob/master'
    for item in os.listdir(basepath):
        if item.endswith('.md'):
            alg = frontmatter.load(basepath + '/' + item).to_dict()
            alg['markdown_link'] = repo_base_url + '/spikeforest/spikeforestsorters/descriptions/' + item
            alg['markdown'] = alg['content']
            del alg['content']
            if 'processor_name' in alg:
                algorithms_by_processor_name[alg['processor_name']] = alg
            Algorithms.append(alg)
    print([alg['label'] for alg in Algorithms])

    # # STUDIES
    # print('******************************** ASSEMBLING STUDIES...')
    # studies_by_name = dict()
    # for study in studies:
    #     studies_by_name[study['name']] = study
    #     study['recordings'] = []
    # for recording in recordings:
    #     studies_by_name[recording['studyName']]['recordings'].append(dict(
    #         name=recording['name'],
    #         studyName=recording['study_name'],
    #         directory=recording['directory'],
    #     ))

    Studies = []
    for study in studies:
        Studies.append(dict(
            name=study['name'],
            studySet=study['study_set'],
            description=study['description'],
            recordings=[]
            # the following can be obtained from the other collections
            # numRecordings, sorters, etc...
        ))
    print([S['name'] for S in Studies])

    print('******************************** ASSEMBLING STUDY SETS...')
    study_sets_by_name = dict()
    for study_set in study_sets:
        study_sets_by_name[study_set['name']] = study_set
        study_set['studies'] = []
    studies_by_name = dict()
    for study in studies:
        study0 = dict(
            name=study['name'],
            studySetName=study['study_set'],
            recordings=[]
        )
        study_sets_by_name[study['study_set']]['studies'].append(study0)
        studies_by_name[study0['name']] = study0
    for recording in recordings:
        true_units_info = mt.loadObject(path=recording['summary']['true_units_info'])
        if not true_units_info:
            print(recording['summary']['true_units_info'])
            raise Exception('Unable to load true_units_info for recording {}'.format(recording['name']))
        recording0 = dict(
            name=recording['name'],
            studyName=recording['study'],
            studySetName=studies_by_name[recording['study']]['studySetName'],
            directory=recording['directory'],
            firingsTrue=recording['firings_true'],
            sampleRateHz=recording['summary']['computed_info']['samplerate'],
            numChannels=recording['summary']['computed_info']['num_channels'],
            durationSec=recording['summary']['computed_info']['duration_sec'],
            numTrueUnits=len(true_units_info),
            spikeSign=-1  # TODO: set this properly
        )
        studies_by_name[recording0['studyName']]['recordings'].append(recording0)
    StudySets = []
    for study_set in study_sets:
        StudySets.append(study_set)
    print(StudySets)

    # SORTING RESULTS
    print('******************************** SORTING RESULTS...')
    SortingResults = []
    for sr in sorting_results:
        if sr['sorter']['name'] in sorters_to_include:
            SR = dict(
                recordingName=sr['recording']['name'],
                studyName=sr['recording']['study'],
                sorterName=sr['sorter']['name'],
                recordingDirectory=sr['recording']['directory'],
                firingsTrue=sr['recording']['firings_true'],
                consoleOut=sr['console_out'],
                container=sr['container'],
                cpuTimeSec=sr['execution_stats'].get('elapsed_sec', None),
                returnCode=sr['execution_stats'].get('retcode', 0),  # TODO: in future, the default should not be 0 -- rather it should be a required field of execution_stats
                timedOut=sr['execution_stats'].get('timed_out', False),
                startTime=datetime.fromtimestamp(sr['execution_stats'].get('start_time')).isoformat(),
                endTime=datetime.fromtimestamp(sr['execution_stats'].get('end_time')).isoformat()
            )
            if sr.get('firings', None):
                SR['firings'] = sr['firings']
                if not sr.get('comparison_with_truth', None):
                    print('Warning: comparison with truth not found for sorting result: {} {}/{}'.format(sr['sorter']['name'], sr['recording']['study'], sr['recording']['name']))
                    print('Console output is here: ' + sr['console_out'])
            else:
                print('Warning: firings not found for sorting result: {} {}/{}'.format(sr['sorter']['name'], sr['recording']['study'], sr['recording']['name']))
                print('Console output is here: ' + sr['console_out'])
            SortingResults.append(SR)
    # print('Num unit results:', len(UnitResults))

    # SORTERS
    print('******************************** ASSEMBLING SORTERS...')
    sorters_by_name = dict()
    for sr in sorting_results:
        sorters_by_name[sr['sorter']['name']] = sr['sorter']
    Sorters = []
    sorter_names = sorted(list(sorters_by_name.keys()))
    sorter_names = [sorter_name for sorter_name in sorter_names if sorter_name in sorters_to_include]
    for sorter_name in sorter_names:
        sorter = sorters_by_name[sorter_name]
        alg = algorithms_by_processor_name.get(sorter['processor_name'], dict())
        alg_label = alg.get('label', sorter['processor_name'])
        if sorter['name'] in sorters_to_include:
            Sorters.append(dict(
                name=sorter['name'],
                algorithmName=alg_label,
                processorName=sorter['processor_name'],
                processorVersion='0',  # jfm to provide this
                sortingParameters=sorter['params']
            ))
    print([S['name'] + ':' + S['algorithmName'] for S in Sorters])

    # STUDY ANALYSIS RESULTS
    print('******************************** ASSEMBLING STUDY ANALYSIS RESULTS...')
    StudyAnalysisResults = [
        _assemble_study_analysis_result(
            study_name=study['name'],
            study_set_name=study['study_set'],
            recordings=recordings,
            sorting_results=sorting_results,
            sorter_names=sorter_names
        )
        for study in studies
    ]

    # GENERAL
    print('******************************** ASSEMBLING GENERAL INFO...')
    General = [dict(
        dateUpdated=datetime.now().isoformat(),
        packageVersions=dict(
            mountaintools=pkg_resources.get_distribution("mountaintools").version,
            spikeforest=pkg_resources.get_distribution("spikeforest").version
        )
    )]

    obj = dict(
        mode='spike-front',
        StudySets=StudySets,
        # TrueUnits=TrueUnits,
        # UnitResults=UnitResults,
        SortingResults=SortingResults,
        Sorters=Sorters,
        Algorithms=Algorithms,
        StudyAnalysisResults=StudyAnalysisResults,
        General=General
    )
    address = mt.saveObject(object=obj)
    mt.createSnapshot(path=address, upload_to=upload_to, dest_path=args.dest_key_path)
Example #14
def genjob(*, processor: mlpr.Processor, fname: str,
           processor_args: dict) -> None:
    job = processor.createJob(**processor_args)
    mt.saveObject(object=job.getObject(), dest_path=fname)
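# A minimal usage sketch for genjob above, reusing the RepeatText processor from the earlier
# example; the output filename is arbitrary.
genjob(processor=RepeatText,
       fname='repeat_text.json',
       processor_args=dict(textfile=mlpr.PLACEHOLDER,
                           textfile_out=dict(ext='.txt'),
                           num_repeats=mlpr.PLACEHOLDER))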
Example #15
def install_jrclust(repo, commit):
    spikeforest_alg_install_path = get_install_path()
    key = dict(alg='jrclust', repo=repo, commit=commit)
    source_path = spikeforest_alg_install_path + '/jrclust_' + commit
    if os.path.exists(source_path):
        # The dir hash method does not seem to be working for some reason here
        # hash0 = mt.computeDirHash(source_path)
        # if hash0 == mt.getValue(key=key):
        #     print('jrclust is already auto-installed.')
        #     return source_path

        a = mt.loadObject(path=source_path + '/spikeforest.json')
        if a:
            if mt.sha1OfObject(a) == mt.sha1OfObject(key):
                print('jrclust is already auto-installed.')
                return source_path

        print('Removing directory: {}'.format(source_path))
        shutil.rmtree(source_path)

    script = """
    #!/bin/bash
    set -e

    git clone {repo} {source_path}
    cd {source_path}
    git checkout {commit}
    """.format(repo=repo, commit=commit, source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Install script returned a non-zero exit code.')

    compile_gpu = mlpr.ShellScript(script="""
    function compile_gpu

    try
        jrc compile
    catch
        disp('Problem running `jrc compile`');
        disp(lasterr());
        exit(-1)
    end;
    exit(0)
    """)
    compile_gpu.write(script_path=source_path + '/compile_gpu.m')

    script = """
    #!/bin/bash
    set -e

    cd {source_path}
    matlab -nodisplay -nosplash -r "compile_gpu"
    """.format(source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Compile GPU script returned a non-zero exit code.')

    # The dir hash method does not seem to be working for some reason here
    # hash0 = mt.computeDirHash(source_path)
    # mt.setValue(key=key, value=hash0)
    mt.saveObject(object=key, dest_path=source_path + '/spikeforest.json')

    return source_path
Example #16
def main():
    parser = argparse.ArgumentParser(
        description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--output_dir',
                        help='The output directory for saving the files.')
    parser.add_argument('--download-from',
                        help='Where to download the data from.',
                        required=False,
                        default='spikeforest.kbucket')
    parser.add_argument('--key_path', help='Key path to retrieve data from')

    args = parser.parse_args()

    output_dir = args.output_dir

    if os.path.exists(output_dir):
        raise Exception(
            'Output directory already exists: {}'.format(output_dir))

    os.mkdir(output_dir)

    mt.configDownloadFrom(args.download_from)

    print('Loading spike-front results object...')
    obj = mt.loadObject(path=args.key_path)

    StudySets = obj['StudySets']
    SortingResults = obj['SortingResults']
    Sorters = obj['Sorters']
    Algorithms = obj['Algorithms']
    StudyAnalysisResults = obj['StudyAnalysisResults']
    General = obj['General']

    print('Saving {} study sets to {}/StudySets.json'.format(
        len(StudySets), output_dir))
    mt.saveObject(object=StudySets, dest_path=output_dir + '/StudySets.json')

    print('Saving {} sorting results to {}/SortingResults.json'.format(
        len(SortingResults), output_dir))
    mt.saveObject(object=SortingResults,
                  dest_path=output_dir + '/SortingResults.json')

    print('Saving {} sorters to {}/Sorters.json'.format(
        len(Sorters), output_dir))
    mt.saveObject(object=Sorters, dest_path=output_dir + '/Sorters.json')

    print('Saving {} algorithms to {}/Algorithms.json'.format(
        len(Algorithms), output_dir))
    mt.saveObject(object=Algorithms, dest_path=output_dir + '/Algorithms.json')

    print('Saving {} study analysis results to {}/StudyAnalysisResults.json'.
          format(len(StudyAnalysisResults), output_dir))
    mt.saveObject(object=StudyAnalysisResults,
                  dest_path=output_dir + '/StudyAnalysisResults.json')

    print('Saving general info to {}/General.json'.format(output_dir))
    mt.saveObject(object=General, dest_path=output_dir + '/General.json')
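# A minimal sketch of consuming the files written above; 'site_data' is a placeholder for
# whatever --output_dir was used.
import json

with open('site_data/StudySets.json') as f:
    study_sets = json.load(f)
print('Loaded {} study sets'.format(len(study_sets)))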
Example #17
def apply_sorters_to_recordings(*, label, sorters, recordings, studies, study_sets, output_id=None, output_path=None, job_timeout=60 * 20, upload_to=None, skip_failing=None):
    # Summarize the recordings
    mtlogging.sublog('summarize-recordings')
    recordings = sa.summarize_recordings(
        recordings=recordings,
        compute_resource='default',
        label='Summarize recordings ({})'.format(label),
        upload_to=upload_to
    )

    # Run the spike sorting
    mtlogging.sublog('sorting')
    sorting_results = sa.multi_sort_recordings(
        sorters=sorters,
        recordings=recordings,
        label='Sort recordings ({})'.format(label),
        job_timeout=job_timeout,
        upload_to=upload_to,
        skip_failing=skip_failing
    )

    # Summarize the sortings
    mtlogging.sublog('summarize-sortings')
    sorting_results = sa.summarize_sortings(
        sortings=sorting_results,
        compute_resource='default',
        label='Summarize sortings ({})'.format(label)
    )

    # Compare with ground truth
    mtlogging.sublog('compare-with-truth')
    sorting_results = sa.compare_sortings_with_truth(
        sortings=sorting_results,
        compute_resource='default',
        label='Compare with truth ({})'.format(label),
        upload_to=upload_to
    )

    # Aggregate the results
    mtlogging.sublog('aggregate')
    aggregated_sorting_results = sa.aggregate_sorting_results(
        studies, recordings, sorting_results)

    output_object = dict(
        studies=studies,
        recordings=recordings,
        study_sets=study_sets,
        sorting_results=sorting_results,
        aggregated_sorting_results=mt.saveObject(
            object=aggregated_sorting_results, upload_to=upload_to)
    )

    # Save the output
    if output_id:
        print('Saving the output')
        mtlogging.sublog('save-output')
        mt.saveObject(
            key=dict(
                name='spikeforest_results'
            ),
            subkey=output_id,
            object=output_object,
            upload_to=upload_to
        )

    if output_path:
        print('Saving the output to {}'.format(output_path))
        mtlogging.sublog('save-output-path')
        address = mt.saveObject(output_object, upload_to=upload_to)
        if not address:
            raise Exception('Problem saving output object.')
        if not mt.createSnapshot(path=address, dest_path=output_path):
            raise Exception('Problem saving output to {}'.format(output_path))

    mtlogging.sublog('show-output-summary')
    for sr in aggregated_sorting_results['study_sorting_results']:
        study_name = sr['study']
        sorter_name = sr['sorter']
        n1 = np.array(sr['num_matches'])
        n2 = np.array(sr['num_false_positives'])
        n3 = np.array(sr['num_false_negatives'])
        accuracies = n1 / (n1 + n2 + n3)
        avg_accuracy = np.mean(accuracies)
        txt = 'STUDY: {}, SORTER: {}, AVG ACCURACY: {}'.format(
            study_name, sorter_name, avg_accuracy)
        print(txt)
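# A worked illustration of the per-unit accuracy summarized above:
# accuracy = matches / (matches + false positives + false negatives); the counts are made up.
import numpy as np

n_match = np.array([80, 90])
n_fp = np.array([10, 5])
n_fn = np.array([10, 5])
accuracies = n_match / (n_match + n_fp + n_fn)
print(np.mean(accuracies))  # 0.85 for these made-up counts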
Example #18
# ... or retrieve the path to a local file containing the text
fname = mt.realizeFile(path=path)
print(fname)
# Output: /tmp/sha1-cache/4/82/482cb0cfcbed6740a2bcb659c9ccc22a4d27b369

# Or we can store some large text by key and retrieve it later
mt.saveText(key=dict(name='key-for-repeating-text'),
            text='some large repeating text'*100)
txt = mt.loadText(key=dict(name='key-for-repeating-text'))
print(len(txt))  # Output: 2500

print('===================')

# Similarly we can store python dicts via json content
path = mt.saveObject(dict(some='object'), basename='object.json')
print(path)
# Output: sha1://b77fdda467b03d7a0c3e06f6f441f689ac46e817/object.json

retrieved_object = mt.loadObject(path=path)
print(retrieved_object)

# Or store objects by key
mt.saveObject(object=dict(some_other='object'), key=dict(some='key'))
obj = mt.loadObject(key=dict(some='key'))
print(obj)

print('===================')

# You can do the same with files
with open('test___.txt', 'w') as f:
    f.write('some file content')  # (the original snippet is truncated at this point)
Example #19
def _set_job_status(job, status):
    mt.saveObject(key=dict(name='job_status', job=job), object=status)
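# A minimal usage sketch pairing this helper with the status dict built in _run_job above;
# the job descriptor is a placeholder and _make_timestamp is assumed to be defined as in
# that example.
job = dict(label='example-job')
status = dict(time_started=_make_timestamp(), status='running')
_set_job_status(job, status)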