def add_job(self, job):
    self._jobs.append(job)
    sorter_name = job['sorterName']
    recording_path = job['recordingPath']
    sorter = getattr(sorters, sorter_name)
    with hither.config(container='default'), hither.job_queue():
        sorting_result = sorter.run(recording_path=recording_path, sorting_out=hither.File())
    job['status'] = 'finished'
    for handler in self._job_updated_handlers:
        handler(job)

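# Hypothetical usage sketch: add_job is a method of a job-manager-style class
# that provides _jobs and _job_updated_handlers (the class itself is not shown
# here; JobManager below is an assumed name). The job dict needs at least the
# keys the method reads; the recording path is a placeholder, not a real
# kachery address.
#
# manager = JobManager()
# manager.add_job(dict(
#     sorterName='mountainsort4',
#     recordingPath='sha1dir://<hash>/some-recording',
# ))
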
def example1_parallel():
    results = []
    job_handler = hither.ParallelJobHandler(10)
    with hither.job_queue(), hither.config(container='default', job_handler=job_handler):
        for n in range(501, 511):
            result = hello_hither_scipy.run(n=n)
            setattr(result, 'n', n)
            results.append(result)
    for result in results:
        n = result.n
        elapsed_sec = result.runtime_info['elapsed_sec']
        retval = result.retval
        print(f'n={n}: result={retval}; elapsed(sec)={elapsed_sec}')

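# The result objects returned by hello_hither_scipy.run(...) inside the
# job_queue context appear to be resolved when the context exits, which is why
# retval and runtime_info are only read after the with block. A minimal serial
# variant of the same pattern (no job handler), reusing only calls shown above:
#
# def example1_serial():
#     with hither.job_queue(), hither.config(container='default'):
#         result = hello_hither_scipy.run(n=501)
#     print(result.retval, result.runtime_info['elapsed_sec'])
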
def test_sort(sorter_name, min_avg_accuracy, recording_path, sorting_true_path, num_jobs=1, job_handler=None, container='default'):
    from spikeforest2 import sorters
    from spikeforest2 import processing
    import kachery as ka

    # for now, in this test, don't use gpu for irc
    gpu = sorter_name in ['kilosort2', 'kilosort', 'tridesclous', 'ironclust']
    sorting_results = []
    with ka.config(fr='default_readonly'):
        with hither.config(container=container, gpu=gpu, job_handler=job_handler), hither.job_queue():
            sorter = getattr(sorters, sorter_name)
            for _ in range(num_jobs):
                sorting_result = sorter.run(recording_path=recording_path, sorting_out=hither.File())
                sorting_results.append(sorting_result)
    assert sorting_result.success

    sorting_result = sorting_results[0]
    with ka.config(fr='default_readonly'):
        with hither.config(container='default', gpu=False):
            compare_result = processing.compare_with_truth.run(
                sorting_path=sorting_result.outputs.sorting_out,
                sorting_true_path=sorting_true_path,
                json_out=hither.File()
            )
    assert compare_result.success

    obj = ka.load_object(compare_result.outputs.json_out._path)
    aa = _average_accuracy(obj)

    print(f'AVERAGE-ACCURACY: {aa}')
    assert aa >= min_avg_accuracy, f"Average accuracy is lower than expected {aa} < {min_avg_accuracy}"
    print('Passed.')

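# Hypothetical invocation sketch for test_sort (e.g. from a pytest wrapper).
# The recording and ground-truth paths are placeholders, not real kachery
# addresses, and the threshold and job handler settings are illustrative only.
#
# test_sort(
#     sorter_name='mountainsort4',
#     min_avg_accuracy=0.5,
#     recording_path='sha1dir://<hash>/001_synth',
#     sorting_true_path='sha1dir://<hash>/001_synth/firings_true.mda',
#     num_jobs=2,
#     job_handler=hither.ParallelJobHandler(2),
# )
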
def main():
    from spikeforest2 import sorters
    from spikeforest2 import processing

    parser = argparse.ArgumentParser(description='Run the SpikeForest2 main analysis')
    # parser.add_argument('analysis_file', help='Path to the analysis specification file (.json format).')
    # parser.add_argument('--config', help='Configuration file', required=True)
    # parser.add_argument('--output', help='Analysis output file (.json format)', required=True)
    # parser.add_argument('--slurm', help='Optional SLURM configuration file (.json format)', required=False, default=None)
    # parser.add_argument('--verbose', help='Provide some additional verbose output.', action='store_true')
    parser.add_argument('spec', help='Path to the .json file containing the analysis specification')
    parser.add_argument('--output', '-o', help='The output .json file', required=True)
    parser.add_argument('--force-run', help='Force rerunning of all spike sorting', action='store_true')
    parser.add_argument('--force-run-all', help='Force rerunning of all spike sorting and other processing', action='store_true')
    parser.add_argument('--parallel', help='Optional number of parallel jobs', required=False, default='0')
    parser.add_argument('--slurm', help='Path to slurm config file', required=False, default=None)
    parser.add_argument('--cache', help='The cache database to use', required=False, default=None)
    parser.add_argument('--rerun-failing', help='Rerun sorting jobs that previously failed', action='store_true')
    parser.add_argument('--test', help='Only run a few.', action='store_true')
    parser.add_argument('--job-timeout', help='Timeout for sorting jobs', required=False, default=600)
    parser.add_argument('--log-file', help='Log file for analysis progress', required=False, default=None)
    args = parser.parse_args()

    force_run_all = args.force_run_all

    # the following apply to sorting jobs only
    force_run = args.force_run or args.force_run_all
    job_timeout = float(args.job_timeout)
    cache_failing = True
    rerun_failing = args.rerun_failing

    with open(args.spec, 'r') as f:
        spec = json.load(f)

    # clear the log file
    if args.log_file is not None:
        with open(args.log_file, 'w'):
            pass

    studysets_path = spec['studysets']
    studyset_names = spec['studyset_names']
    spike_sorters = spec['spike_sorters']

    ka.set_config(fr='default_readonly')

    print(f'Loading study sets object from: {studysets_path}')
    studysets_obj = ka.load_object(studysets_path)
    if not studysets_obj:
        raise Exception(f'Unable to load: {studysets_path}')
    all_study_sets = studysets_obj['StudySets']
    study_sets = []
    for studyset in all_study_sets:
        if studyset['name'] in studyset_names:
            study_sets.append(studyset)

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
        job_handler_gpu = job_handler
        job_handler_ks = job_handler
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['cpu']
        )
        job_handler_gpu = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['gpu']
        )
        job_handler_ks = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['ks']
        )
    else:
        job_handler = None
        job_handler_gpu = None
        job_handler_ks = None

    with hither.config(
        container='default',
        cache=args.cache,
        force_run=force_run_all,
        job_handler=job_handler,
        log_path=args.log_file
    ), hither.job_queue():
        studies = []
        recordings = []
        for studyset in study_sets:
            studyset_name = studyset['name']
            print(f'================ STUDY SET: {studyset_name}')
            studies0 = studyset['studies']
            if args.test:
                studies0 = studies0[:1]
            studyset['studies'] = studies0
            for study in studies0:
                study['study_set'] = studyset_name
                study_name = study['name']
                print(f'======== STUDY: {study_name}')

                recordings0 = study['recordings']
                if args.test:
                    recordings0 = recordings0[:2]
                study['recordings'] = recordings0
                for recording in recordings0:
                    recording['study'] = study_name
                    recording['study_set'] = studyset_name
                    recording['firings_true'] = recording['firingsTrue']
                    recordings.append(recording)
                studies.append(study)

        # Download recordings
        for recording in recordings:
            print(f'Downloading recording: {recording["study"]}/{recording["name"]}')
            ka.load_file(recording['directory'] + '/raw.mda')
            ka.load_file(recording['directory'] + '/params.json')
            ka.load_file(recording['directory'] + '/geom.csv')
            ka.load_file(recording['directory'] + '/firings_true.mda')

        # Attach results objects
        for recording in recordings:
            recording['results'] = dict()

        # Summarize recordings
        for recording in recordings:
            recording_path = recording['directory']
            sorting_true_path = recording['firingsTrue']
            recording['results']['computed-info'] = processing.compute_recording_info.run(
                _label=f'compute-recording-info:{recording["study"]}/{recording["name"]}',
                recording_path=recording_path,
                json_out=hither.File()
            )
            recording['results']['true-units-info'] = processing.compute_units_info.run(
                _label=f'compute-units-info:{recording["study"]}/{recording["name"]}',
                recording_path=recording_path,
                sorting_path=sorting_true_path,
                json_out=hither.File()
            )

        # Spike sorting
        for sorter in spike_sorters:
            for recording in recordings:
                if recording['study_set'] in sorter['studysets']:
                    recording_path = recording['directory']
                    sorting_true_path = recording['firingsTrue']

                    algorithm = sorter['processor_name']
                    if not hasattr(sorters, algorithm):
                        raise Exception(f'No such sorting algorithm: {algorithm}')
                    Sorter = getattr(sorters, algorithm)

                    if algorithm in ['ironclust-disable', 'tridesclous']:
                        gpu = True
                        jh = job_handler_gpu
                    elif algorithm in ['kilosort', 'kilosort2']:
                        gpu = True
                        jh = job_handler_ks
                    else:
                        gpu = False
                        jh = job_handler

                    with hither.config(gpu=gpu, force_run=force_run, exception_on_fail=False, cache_failing=cache_failing, rerun_failing=rerun_failing, job_handler=jh, job_timeout=job_timeout):
                        sorting_result = Sorter.run(
                            _label=f'{algorithm}:{recording["study"]}/{recording["name"]}',
                            recording_path=recording['directory'],
                            sorting_out=hither.File()
                        )
                        recording['results']['sorting-' + sorter['name']] = sorting_result

                    recording['results']['comparison-with-truth-' + sorter['name']] = processing.compare_with_truth.run(
                        _label=f'comparison-with-truth:{algorithm}:{recording["study"]}/{recording["name"]}',
                        sorting_path=sorting_result.outputs.sorting_out,
                        sorting_true_path=sorting_true_path,
                        json_out=hither.File()
                    )
                    recording['results']['units-info-' + sorter['name']] = processing.compute_units_info.run(
                        _label=f'units-info:{algorithm}:{recording["study"]}/{recording["name"]}',
                        recording_path=recording_path,
                        sorting_path=sorting_result.outputs.sorting_out,
                        json_out=hither.File()
                    )

    # Assemble all of the results
    print('')
    print('=======================================================')
    print('Assembling results...')
    for recording in recordings:
        print(f'Assembling recording: {recording["study"]}/{recording["name"]}')
        recording['summary'] = dict(
            plots=dict(),
            computed_info=ka.load_object(recording['results']['computed-info'].outputs.json_out._path),
            true_units_info=ka.store_file(recording['results']['true-units-info'].outputs.json_out._path)
        )

    sorting_results = []
    for sorter in spike_sorters:
        for recording in recordings:
            if recording['study_set'] in sorter['studysets']:
                print(f'Assembling sorting: {sorter["processor_name"]} {recording["study"]}/{recording["name"]}')
                sorting_result = recording['results']['sorting-' + sorter['name']]
                comparison_result = recording['results']['comparison-with-truth-' + sorter['name']]
                units_info_result = recording['results']['units-info-' + sorter['name']]

                console_out_str = _console_out_to_str(sorting_result.runtime_info['console_out'])
                console_out_path = ka.store_text(console_out_str)

                sr = dict(
                    recording=recording,
                    sorter=sorter,
                    firings_true=recording['directory'] + '/firings_true.mda',
                    processor_name=sorter['processor_name'],
                    processor_version=sorting_result.version,
                    sorting_parameters=sorter['params'],
                    execution_stats=dict(
                        start_time=sorting_result.runtime_info['start_time'],
                        end_time=sorting_result.runtime_info['end_time'],
                        elapsed_sec=sorting_result.runtime_info['end_time'] - sorting_result.runtime_info['start_time'],
                        reported_elapsed_sec=_parse_spikeforest_runtime(console_out_str),
                        retcode=0 if sorting_result.success else -1,
                        timed_out=sorting_result.runtime_info.get('timed_out', False)
                    ),
                    container=sorting_result.container,
                    console_out=console_out_path
                )
                if sorting_result.success:
                    sr['firings'] = ka.store_file(sorting_result.outputs.sorting_out._path)
                    sr['comparison_with_truth'] = dict(
                        json=ka.store_file(comparison_result.outputs.json_out._path)
                    )
                    sr['sorted_units_info'] = ka.store_file(units_info_result.outputs.json_out._path)
                else:
                    sr['firings'] = None
                    sr['comparison_with_truth'] = None
                    sr['sorted_units_info'] = None
                sorting_results.append(sr)

    # Delete results from recordings
    for recording in recordings:
        del recording['results']

    # Aggregate sorting results
    print('')
    print('=======================================================')
    print('Aggregating sorting results...')
    aggregated_sorting_results = aggregate_sorting_results(studies, recordings, sorting_results)

    # Show output summary
    for sr in aggregated_sorting_results['study_sorting_results']:
        study_name = sr['study']
        sorter_name = sr['sorter']
        n1 = np.array(sr['num_matches'])
        n2 = np.array(sr['num_false_positives'])
        n3 = np.array(sr['num_false_negatives'])
        accuracies = n1 / (n1 + n2 + n3)
        avg_accuracy = np.mean(accuracies)
        txt = 'STUDY: {}, SORTER: {}, AVG ACCURACY: {}'.format(study_name, sorter_name, avg_accuracy)
        print(txt)

    output_object = dict(
        studies=studies,
        recordings=recordings,
        study_sets=study_sets,
        sorting_results=sorting_results,
        aggregated_sorting_results=ka.store_object(aggregated_sorting_results, basename='aggregated_sorting_results.json')
    )

    print(f'Writing output to {args.output}...')
    with open(args.output, 'w') as f:
        json.dump(output_object, f, indent=4)
    print('Done.')

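# Sketch of the analysis spec consumed by main() above (the file passed as the
# 'spec' positional argument). The top-level keys and the per-sorter fields
# ('name', 'processor_name', 'params', 'studysets') follow how the code reads
# the spec; the concrete values (study set name, sorter entry, kachery path)
# are hypothetical placeholders.
#
# {
#     "studysets": "sha1://<hash>/studysets.json",
#     "studyset_names": ["SYNTH_MAGLAND"],
#     "spike_sorters": [
#         {
#             "name": "MountainSort4",
#             "processor_name": "mountainsort4",
#             "params": {},
#             "studysets": ["SYNTH_MAGLAND"]
#         }
#     ]
# }
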
def main():
    from mountaintools import client as mt

    parser = argparse.ArgumentParser(description='Generate unit detail data (including spikesprays) for website')
    parser.add_argument('analysis_path', help='assembled analysis file (output.json)')
    parser.add_argument('--studysets', help='Comma-separated list of study set names to include', required=False, default=None)
    parser.add_argument('--force-run', help='Force rerunning of processing', action='store_true')
    parser.add_argument('--force-run-all', help='Force rerunning of processing including filtering', action='store_true')
    parser.add_argument('--parallel', help='Optional number of parallel jobs', required=False, default='0')
    parser.add_argument('--slurm', help='Path to slurm config file', required=False, default=None)
    parser.add_argument('--cache', help='The cache database to use', required=False, default=None)
    parser.add_argument('--job-timeout', help='Timeout for processing jobs', required=False, default=600)
    parser.add_argument('--log-file', help='Log file for analysis progress', required=False, default=None)
    parser.add_argument('--force-regenerate', help='Whether to force regenerating spike sprays (for when code has changed)', action='store_true')
    parser.add_argument('--test', help='Whether to just test by running only 1', action='store_true')
    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    with open(args.analysis_path, 'r') as f:
        analysis = json.load(f)

    if args.studysets is not None:
        studyset_names = args.studysets.split(',')
        print('Using study sets: ', studyset_names)
    else:
        studyset_names = None

    study_sets = analysis['StudySets']
    sorting_results = analysis['SortingResults']

    studies_to_include = []
    for ss in study_sets:
        if (studyset_names is None) or (ss['name'] in studyset_names):
            for study in ss['studies']:
                studies_to_include.append(study['name'])

    print('Including studies:', studies_to_include)

    print('Determining sorting results to process ({} total)...'.format(len(sorting_results)))
    sorting_results_to_process = []
    sorting_results_to_consider = []
    for sr in sorting_results:
        study_name = sr['studyName']
        if study_name in studies_to_include:
            if 'firings' in sr:
                if sr.get('comparisonWithTruth', None) is not None:
                    sorting_results_to_consider.append(sr)
                    key = dict(
                        name='unit-details-v0.1.0',
                        recording_directory=sr['recordingDirectory'],
                        firings_true=sr['firingsTrue'],
                        firings=sr['firings']
                    )
                    val = mt.getValue(key=key, collection='spikeforest')
                    if (not val) or (args.force_regenerate):
                        sr['key'] = key
                        sorting_results_to_process.append(sr)

    if args.test and len(sorting_results_to_process) > 0:
        sorting_results_to_process = [sorting_results_to_process[0]]

    print('Need to process {} of {} sorting results'.format(len(sorting_results_to_process), len(sorting_results_to_consider)))

    recording_directories_to_process = sorted(list(set([sr['recordingDirectory'] for sr in sorting_results_to_process])))
    print('{} recording directories to process'.format(len(recording_directories_to_process)))

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm', **slurm_config['cpu'])
    else:
        job_handler = None

    print('Filtering recordings...')
    filter_results = []
    with hither.config(
        container='default',
        cache=args.cache,
        force_run=args.force_run_all,
        job_handler=job_handler,
        log_path=args.log_file,
        exception_on_fail=True,
        cache_failing=False,
        rerun_failing=True,
        job_timeout=float(args.job_timeout)
    ), hither.job_queue():
        for recdir in recording_directories_to_process:
            result = filter_recording.run(recording_directory=recdir, timeseries_out=hither.File())
            filter_results.append(result)

    filtered_timeseries_by_recdir = dict()
    for i, recdir in enumerate(recording_directories_to_process):
        result0 = filter_results[i]
        if not result0.success:
            raise Exception('Problem computing filtered timeseries for recording: {}'.format(recdir))
        filtered_timeseries_by_recdir[recdir] = result0.outputs.timeseries_out._path

    print('Creating spike sprays...')
    with hither.config(
        container='default',
        cache=args.cache,
        force_run=args.force_run or args.force_run_all,
        job_handler=job_handler,
        log_path=args.log_file,
        exception_on_fail=True,
        cache_failing=False,
        rerun_failing=True,
        job_timeout=float(args.job_timeout)
    ), hither.job_queue():
        for sr in sorting_results_to_process:
            recdir = sr['recordingDirectory']
            study_name = sr['studyName']
            rec_name = sr['recordingName']
            sorter_name = sr['sorterName']
            print('====== COMPUTING {}/{}/{}'.format(study_name, rec_name, sorter_name))
            cwt = ka.load_object(path=sr['comparisonWithTruth']['json'])

            filtered_timeseries = filtered_timeseries_by_recdir[recdir]

            spike_spray_results = []
            list0 = list(cwt.values())
            for unit in list0:
                result = create_spike_sprays.run(
                    recording_directory=recdir,
                    filtered_timeseries=filtered_timeseries,
                    firings_true=os.path.join(recdir, 'firings_true.mda'),
                    firings_sorted=sr['firings'],
                    unit_id_true=unit['unit_id'],
                    unit_id_sorted=unit['best_unit'],
                    json_out=hither.File()
                )
                setattr(result, 'unit', unit)
                spike_spray_results.append(result)
            sr['spike_spray_results'] = spike_spray_results

    for sr in sorting_results_to_process:
        recdir = sr['recordingDirectory']
        study_name = sr['studyName']
        rec_name = sr['recordingName']
        sorter_name = sr['sorterName']
        print('====== SAVING {}/{}/{}'.format(study_name, rec_name, sorter_name))
        spike_spray_results = sr['spike_spray_results']
        key = sr['key']

        unit_details = []
        ok = True
        for result in spike_spray_results:
            if not result.success:
                print('WARNING: Error creating spike sprays for {}/{}/{}'.format(study_name, rec_name, sorter_name))
                ok = False
                break
            ssobj = ka.load_object(result.outputs.json_out._path)
            if ssobj is None:
                raise Exception('Problem loading spikespray object output.')
            address = mt.saveObject(object=ssobj, upload_to='spikeforest.kbucket')
            unit = getattr(result, 'unit')
            unit_details.append(dict(
                studyName=study_name,
                recordingName=rec_name,
                sorterName=sorter_name,
                trueUnitId=unit['unit_id'],
                sortedUnitId=unit['best_unit'],
                spikeSprayUrl=mt.findFile(path=address, remote_only=True, download_from='spikeforest.kbucket')
            ))

        if ok:
            mt.saveObject(collection='spikeforest', key=key, object=unit_details, upload_to='spikeforest.public')