Example #1
def example2():
    with hither.config(container='default'):
        result1 = intentional_exception.run(raise_exception=False)
        try:
            intentional_exception.run(raise_exception=True)
            got_exception = False
        except Exception:
            got_exception = True
        assert got_exception, "Did not get exception"
        with hither.config(exception_on_fail=False):
            result3 = intentional_exception.run(raise_exception=True)
    assert result1.success, "result1.success should be True"
    assert (not result3.success), "result3.success should be False"
    print(result1.retval, result3.retval)
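The example above assumes a hither function registered under the name intentional_exception. A minimal sketch of what such a function could look like, assuming the decorator-based registration used by the spikeforest2-era hither API (the name, version string, and body below are illustrative, not the project's actual definition):

import hither

# Illustrative sketch only: a hither function that fails on demand.
@hither.function('intentional_exception', '0.1.0')
def intentional_exception(raise_exception):
    if raise_exception:
        raise Exception('intentional-exception')
    return 'no-exception'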
Example #2
def test_sort(sorter_name,
              min_avg_accuracy,
              recording_path,
              sorting_true_path,
              num_jobs=1,
              job_handler=None,
              container='default'):
    from spikeforest2 import sorters
    from spikeforest2 import processing
    import kachery as ka

    # for now, in this test, don't use gpu for irc (ironclust)
    gpu = sorter_name in ['kilosort2', 'kilosort', 'tridesclous']

    sorting_results = []
    with ka.config(fr='default_readonly'):
        with hither.config(container=container,
                           gpu=gpu,
                           job_handler=job_handler), hither.job_queue():
            sorter = getattr(sorters, sorter_name)
            for _ in range(num_jobs):
                sorting_result = sorter.run(recording_path=recording_path,
                                            sorting_out=hither.File())
                sorting_results.append(sorting_result)

    for sorting_result in sorting_results:
        assert sorting_result.success

    sorting_result = sorting_results[0]
    with ka.config(fr='default_readonly'):
        with hither.config(container='default', gpu=False):
            compare_result = processing.compare_with_truth.run(
                sorting_path=sorting_result.outputs.sorting_out,
                sorting_true_path=sorting_true_path,
                json_out=hither.File())

    assert compare_result.success

    obj = ka.load_object(compare_result.outputs.json_out._path)

    aa = _average_accuracy(obj)

    print(f'AVERAGE-ACCURACY: {aa}')

    assert aa >= min_avg_accuracy, f"Average accuracy is lower than expected {aa} < {min_avg_accuracy}"

    print('Passed.')
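test_sort relies on a helper _average_accuracy that is not shown here. A plausible sketch, assuming the comparison-with-truth JSON maps unit labels to dicts with num_matches, num_false_positives, and num_false_negatives fields (the same quantities aggregated in Example #13); the field names are assumptions rather than a documented contract:

import numpy as np

def _average_accuracy(obj):
    # Illustrative sketch: mean per-unit accuracy,
    # num_matches / (num_matches + num_false_positives + num_false_negatives).
    accuracies = []
    for unit in obj.values():
        n1 = unit['num_matches']
        n2 = unit['num_false_positives']
        n3 = unit['num_false_negatives']
        accuracies.append(n1 / (n1 + n2 + n3))
    return float(np.mean(accuracies))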
Example #3
def sort(algorithm: str,
         recording_path: str,
         sorting_out: str = None,
         params: dict = None,
         container: str = 'default',
         git_annex_mode=True,
         use_singularity: bool = False,
         job_timeout: float = 3600) -> str:

    from spikeforest2 import sorters
    HITHER_USE_SINGULARITY = os.getenv('HITHER_USE_SINGULARITY')
    if HITHER_USE_SINGULARITY is None:
        HITHER_USE_SINGULARITY = False
    print('HITHER_USE_SINGULARITY: {}'.format(HITHER_USE_SINGULARITY))
    if not hasattr(sorters, algorithm):
        raise Exception('Sorter not found: {}'.format(algorithm))
    sorter = getattr(sorters, algorithm)
    if algorithm in [
            'kilosort2', 'kilosort', 'ironclust', 'tridesclous', 'jrclust'
    ]:
        gpu = True
    else:
        gpu = False
    if not sorting_out:
        sorting_out = hither.File()
    if not (recording_path.startswith('sha1dir://')
            or recording_path.startswith('sha1://')):
        if os.path.isfile(recording_path):
            recording_path = ka.store_file(recording_path)
        elif os.path.isdir(recording_path):
            recording_path = ka.store_dir(recording_path,
                                          git_annex_mode=git_annex_mode)
    if params is None:
        params = dict()
    params_hither = dict(gpu=gpu, container=container)
    if job_timeout is not None:
        params_hither['job_timeout'] = job_timeout
    with hither.config(**params_hither):
        result = sorter.run(recording_path=recording_path,
                            sorting_out=sorting_out,
                            **params)
    print('SORTING')
    print('==============================================')
    return ka.store_file(result.outputs.sorting_out._path,
                         basename='firings.mda')


# def set_params(sorter, params_file):
#     params = {}
#     names_float = ['detection_thresh']
#     with open(params_file, 'r') as myfile:
#         for line in myfile:
#             name, var = line.partition("=")[::2]
#             name = name.strip()

#             params[name.strip()] = var
#     sorter.set_params(**params)
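A hypothetical call of sort() (the algorithm name and recording address are placeholders); the return value is the kachery address of the stored firings.mda:

# Illustrative usage only; 'sha1dir://<recording-directory>' stands in for a real address.
firings = sort(algorithm='mountainsort4',
               recording_path='sha1dir://<recording-directory>',
               container='default',
               job_timeout=3600)
print(firings)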
Example #4
def example3():
    X = hither.File()
    create_text_file.run(text='some-text', intentional_exception=False, output_file=X)
    print_file.run(input_file=X)

    with hither.config(exception_on_fail=False):
        X2 = hither.File()
        create_text_file.run(text='some-text', intentional_exception=True, output_file=X2)
        result = print_file.run(input_file=X2)
        print(result.success)
Example #5
def add_job(self, job):
    self._jobs.append(job)
    sorter_name = job['sorterName']
    recording_path = job['recordingPath']
    sorter = getattr(sorters, sorter_name)
    with hither.config(container='default'), hither.job_queue():
        sorting_result = sorter.run(recording_path=recording_path,
                                    sorting_out=hither.File())
    job['status'] = 'finished'
    for handler in self._job_updated_handlers:
        handler(job)
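A hypothetical way this method might be driven (queue stands for an instance of the class that owns add_job; the job dict carries the keys read above, and the recording address is a placeholder):

# Illustrative usage only.
queue.add_job({'sorterName': 'mountainsort4',
               'recordingPath': 'sha1dir://<recording-directory>'})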
Example #6
def example1_parallel():
    results = []
    job_handler = hither.ParallelJobHandler(10)
    with hither.job_queue(), hither.config(container='default', job_handler=job_handler):
        for n in range(501, 511):
            result = hello_hither_scipy.run(n=n)
            setattr(result, 'n', n)
            results.append(result)
    for result in results:
        n = result.n
        elapsed_sec = result.runtime_info['elapsed_sec']
        retval = result.retval
        print(f'n={n}: result={retval}; elapsed(sec)={elapsed_sec}')
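The same pattern works with hither's SLURM-backed handler in place of ParallelJobHandler (a sketch; slurm_cpu_opts is a hypothetical dict of SlurmJobHandler options, loaded from a site-specific config as in Example #13):

# Illustrative sketch only.
job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm', **slurm_cpu_opts)
with hither.job_queue(), hither.config(container='default', job_handler=job_handler):
    results = [hello_hither_scipy.run(n=n) for n in range(501, 511)]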
Example #7
def sort(algorithm: str, recording_path: str):
    from spikeforest2 import sorters
    if not hasattr(sorters, algorithm):
        raise Exception('Sorter not found: {}'.format(algorithm))
    sorter = getattr(sorters, algorithm)
    if algorithm in ['kilosort2', 'ironclust']:
        gpu = True
    else:
        gpu = False
    with hither.config(gpu=gpu):
        result = sorter.run(recording_path=recording_path,
                            sorting_out=hither.File())
    print('SORTING')
    print('==============================================')
    return ka.store_file(result.outputs.sorting_out._path,
                         basename='firings.mda')
Example #8
def main():
    import spikeextractors as se
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval
    
    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    num_events = 1000
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100
    k = 20
    ncomp = 4
    
    R = AutoRecordingExtractor(recobj2)

    X = R.get_traces()
    
    sig = X.copy()
    if detect_sign < 0:
        sig = -sig
    elif detect_sign == 0:
        sig = np.abs(sig)
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
    times_reference = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval_reference, detect_sign=1, margin=1000)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    for j in range(snippets_reference.shape[0]):
        for m in range(snippets_reference.shape[1]):
            snippets_reference[j, m, :] = snippets_reference[j, m, :] * window0
    A_snippets_reference = snippets_reference.reshape(snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    u, s, vh = np.linalg.svd(A_snippets_reference, full_matrices=False)  # economy SVD: only vh is used below
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        sig = X.copy()
        if detect_sign < 0:
            sig = -sig
        elif detect_sign == 0:
            sig = np.abs(sig)
        sig = np.max(sig, axis=0)
        noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
        times = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval, detect_sign=1, margin=1000)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        for j in range(snippets.shape[0]):
            for m in range(snippets.shape[1]):
                snippets[j, m, :] = snippets[j, m, :] * window0
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        features = A_snippets @ components_reference
        
        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        A_snippets_denoised = features2 @ components_reference.T
        
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        for j in range(len(times)):
            t0 = times[j]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] = X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] + snippet_denoised_0
            X[:, t0-snippet_len[0]:t0+snippet_len[1]] = X[:, t0-snippet_len[0]:t0+snippet_len[1]] - snippet_denoised_0

    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
    
    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()
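The pipeline above uses a helper extract_snippets that is not shown. A minimal sketch consistent with how it is used here (events on axis 0, channels on axis 1, samples on axis 2, cutting snippet_len[0] frames before and snippet_len[1] frames after each reference frame); this is an illustrative reimplementation, not the project's actual helper:

import numpy as np

def extract_snippets(X, *, reference_frames, snippet_len):
    # Illustrative sketch: cut a fixed-length window around each event time.
    a, b = snippet_len
    snippets = np.zeros((len(reference_frames), X.shape[0], a + b), dtype=X.dtype)
    for i, t0 in enumerate(reference_frames):
        snippets[i] = X[:, t0 - a:t0 + b]
    return snippets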
Example #9
sorter_name = 'kilosort2'
sorter = getattr(sorters, sorter_name)
params = {}

# Determine whether we are going to use gpu based on the name of the sorter
gpu = sorter_name in ['kilosort2', 'kilosort', 'tridesclous', 'ironclust']

# In the future we will check whether we have the correct version of the wrapper here
# Version: 0.1.5-w1

# Download the data (if needed)
ka.set_config(fr='default_readonly')
ka.load_file(recording_path + '/raw.mda')

# Run the spike sorting
with hither.config(container='default', gpu=gpu):
    sorting_result = sorter.run(recording_path=recording_path,
                                sorting_out=hither.File(),
                                **params)
assert sorting_result.success
sorting_path = sorting_result.outputs.sorting_out

# Compare with ground truth
with hither.config(container='default'):
    compare_result = processing.compare_with_truth.run(
        sorting_path=sorting_path,
        sorting_true_path=sorting_true_path,
        json_out=hither.File())
assert compare_result.success
obj = ka.load_object(compare_result.outputs.json_out._path)
Example #10
sorter_name = 'kilosort2'
sorter = getattr(sorters, sorter_name)
params = {}

# Determine whether we are going to use gpu based on the name of the sorter
gpu = sorter_name in ['kilosort2', 'kilosort', 'tridesclous', 'ironclust']

# In the future we will check whether we have the correct version of the wrapper here
# Version: 0.1.5-w1

# Download the data (if needed)
ka.set_config(fr='default_readonly')
ka.load_file(recording_path + '/raw.mda')

# Run the spike sorting
with hither.config(container='docker://magland/sf-kilosort2:0.1.5', gpu=gpu):
    sorting_result = sorter.run(recording_path=recording_path,
                                sorting_out=hither.File(),
                                **params)
assert sorting_result.success
sorting_path = sorting_result.outputs.sorting_out

# Compare with ground truth
with hither.config(container='default'):
    compare_result = processing.compare_with_truth.run(
        sorting_path=sorting_path,
        sorting_true_path=sorting_true_path,
        json_out=hither.File())
assert compare_result.success
obj = ka.load_object(compare_result.outputs.json_out._path)
Example #11
#!/usr/bin/env python

from spikeforest2 import sorters
import hither
import kachery as ka

recording_path = 'sha1://961f4a641af64dded4821610189f808f0192de4d/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10_C4/002_synth.json'

with ka.config(fr='default_readonly'):
    #with hither.config(cache='default_readwrite'):
    with hither.config(container='default', gpu=True):
        result = sorters.kilosort2.run(recording_path=recording_path,
                                       sorting_out=hither.File())

print(result.outputs.sorting_out)
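The commented-out line hints at hither's result cache. Enabling it (a sketch that reuses the 'default_readwrite' cache name appearing in Example #8) lets an identical rerun return the stored result instead of sorting again:

# Illustrative sketch only: the same job as above with result caching enabled.
with ka.config(fr='default_readonly'):
    with hither.config(cache='default_readwrite', container='default', gpu=True):
        result = sorters.kilosort2.run(recording_path=recording_path,
                                       sorting_out=hither.File())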
Example #12
#!/usr/bin/env python

from spikeforest2 import sorters
import hither
import kachery as ka

recording_path = 'sha1://961f4a641af64dded4821610189f808f0192de4d/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10_C4/002_synth.json'

with ka.config(fr='default_readonly'):
    # with hither.config(cache='default_readwrite'):
    with hither.config(container='default'):
        result = sorters.ironclust.run(
            recording_path=recording_path,
            sorting_out=hither.File()
        )

print(result.outputs.sorting_out)
Example #13
def main():
    from spikeforest2 import sorters
    from spikeforest2 import processing

    parser = argparse.ArgumentParser(description='Run the SpikeForest2 main analysis')
    # parser.add_argument('analysis_file', help='Path to the analysis specification file (.json format).')
    # parser.add_argument('--config', help='Configuration file', required=True)
    # parser.add_argument('--output', help='Analysis output file (.json format)', required=True)
    # parser.add_argument('--slurm', help='Optional SLURM configuration file (.json format)', required=False, default=None)
    # parser.add_argument('--verbose', help='Provide some additional verbose output.', action='store_true')
    parser.add_argument('spec', help='Path to the .json file containing the analysis specification')
    parser.add_argument('--output', '-o', help='The output .json file', required=True)
    parser.add_argument('--force-run', help='Force rerunning of all spike sorting', action='store_true')
    parser.add_argument('--force-run-all', help='Force rerunning of all spike sorting and other processing', action='store_true')
    parser.add_argument('--parallel', help='Optional number of parallel jobs', required=False, default='0')    
    parser.add_argument('--slurm', help='Path to slurm config file', required=False, default=None)
    parser.add_argument('--cache', help='The cache database to use', required=False, default=None)
    parser.add_argument('--rerun-failing', help='Rerun sorting jobs that previously failed', action='store_true')
    parser.add_argument('--test', help='Only run a few.', action='store_true')
    parser.add_argument('--job-timeout', help='Timeout for sorting jobs', required=False, default=600)
    parser.add_argument('--log-file', help='Log file for analysis progress', required=False, default=None)

    args = parser.parse_args()
    force_run_all = args.force_run_all

    # the following apply to sorting jobs only
    force_run = args.force_run or args.force_run_all
    job_timeout = float(args.job_timeout)
    cache_failing = True
    rerun_failing = args.rerun_failing

    with open(args.spec, 'r') as f:
        spec = json.load(f)

    # clear the log file    
    if args.log_file is not None:
        with open(args.log_file, 'w'):
            pass

    studysets_path = spec['studysets']
    studyset_names = spec['studyset_names']
    spike_sorters = spec['spike_sorters']

    ka.set_config(fr='default_readonly')

    print(f'Loading study sets object from: {studysets_path}')
    studysets_obj = ka.load_object(studysets_path)
    if not studysets_obj:
        raise Exception(f'Unable to load: {studysets_path}')
    
    all_study_sets = studysets_obj['StudySets']
    study_sets = []
    for studyset in all_study_sets:
        if studyset['name'] in studyset_names:
            study_sets.append(studyset)
    
    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
        job_handler_gpu = job_handler
        job_handler_ks = job_handler
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['cpu']
        )
        job_handler_gpu = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['gpu']
        )
        job_handler_ks = hither.SlurmJobHandler(
            working_dir='tmp_slurm',
            **slurm_config['ks']
        )
    else:
        job_handler = None
        job_handler_gpu = None
        job_handler_ks = None

    with hither.config(
        container='default',
        cache=args.cache,
        force_run=force_run_all,
        job_handler=job_handler,
        log_path=args.log_file
    ), hither.job_queue():
        studies = []
        recordings = []
        for studyset in study_sets:
            studyset_name = studyset['name']
            print(f'================ STUDY SET: {studyset_name}')
            studies0 = studyset['studies']
            if args.test:
                studies0 = studies0[:1]
                studyset['studies'] = studies0
            for study in studies0:
                study['study_set'] = studyset_name
                study_name = study['name']
                print(f'======== STUDY: {study_name}')
                recordings0 = study['recordings']
                if args.test:
                    recordings0 = recordings0[:2]
                    study['recordings'] = recordings0
                for recording in recordings0:
                    recording['study'] = study_name
                    recording['study_set'] = studyset_name
                    recording['firings_true'] = recording['firingsTrue']
                    recordings.append(recording)
                studies.append(study)

        # Download recordings
        for recording in recordings:
            print(f'Downloading recording: {recording["study"]}/{recording["name"]}')
            ka.load_file(recording['directory'] + '/raw.mda')
            ka.load_file(recording['directory'] + '/params.json')
            ka.load_file(recording['directory'] + '/geom.csv')
            ka.load_file(recording['directory'] + '/firings_true.mda')
        
        # Attach results objects
        for recording in recordings:
            recording['results'] = dict()
        
        # Summarize recordings
        for recording in recordings:
            recording_path = recording['directory']
            sorting_true_path = recording['firingsTrue']
            recording['results']['computed-info'] = processing.compute_recording_info.run(
                _label=f'compute-recording-info:{recording["study"]}/{recording["name"]}',
                recording_path=recording_path,
                json_out=hither.File()
            )
            recording['results']['true-units-info'] = processing.compute_units_info.run(
                _label=f'compute-units-info:{recording["study"]}/{recording["name"]}',
                recording_path=recording_path,
                sorting_path=sorting_true_path,
                json_out=hither.File()
            )
        
        # Spike sorting
        for sorter in spike_sorters:
            for recording in recordings:
                if recording['study_set'] in sorter['studysets']:
                    recording_path = recording['directory']
                    sorting_true_path = recording['firingsTrue']

                    algorithm = sorter['processor_name']
                    if not hasattr(sorters, algorithm):
                        raise Exception(f'No such sorting algorithm: {algorithm}')
                    Sorter = getattr(sorters, algorithm)

                    if algorithm in ['ironclust-disable', 'tridesclous']:
                        gpu = True
                        jh = job_handler_gpu
                    elif algorithm in ['kilosort', 'kilosort2']:
                        gpu = True
                        jh = job_handler_ks
                    else:
                        gpu = False
                        jh = job_handler
                    with hither.config(gpu=gpu, force_run=force_run, exception_on_fail=False, cache_failing=cache_failing, rerun_failing=rerun_failing, job_handler=jh, job_timeout=job_timeout):
                        sorting_result = Sorter.run(
                            _label=f'{algorithm}:{recording["study"]}/{recording["name"]}',
                            recording_path=recording['directory'],
                            sorting_out=hither.File()
                        )
                        recording['results']['sorting-' + sorter['name']] = sorting_result
                    recording['results']['comparison-with-truth-' + sorter['name']] = processing.compare_with_truth.run(
                        _label=f'comparison-with-truth:{algorithm}:{recording["study"]}/{recording["name"]}',
                        sorting_path=sorting_result.outputs.sorting_out,
                        sorting_true_path=sorting_true_path,
                        json_out=hither.File()
                    )
                    recording['results']['units-info-' + sorter['name']] = processing.compute_units_info.run(
                        _label=f'units-info:{algorithm}:{recording["study"]}/{recording["name"]}',
                        recording_path=recording_path,
                        sorting_path=sorting_result.outputs.sorting_out,
                        json_out=hither.File()
                    )

    # Assemble all of the results
    print('')
    print('=======================================================')
    print('Assembling results...')
    for recording in recordings:
        print(f'Assembling recording: {recording["study"]}/{recording["name"]}')
        recording['summary'] = dict(
            plots=dict(),
            computed_info=ka.load_object(recording['results']['computed-info'].outputs.json_out._path),
            true_units_info=ka.store_file(recording['results']['true-units-info'].outputs.json_out._path)
        )
    sorting_results = []
    for sorter in spike_sorters:
        for recording in recordings:
            if recording['study_set'] in sorter['studysets']:
                print(f'Assembling sorting: {sorter["processor_name"]} {recording["study"]}/{recording["name"]}')
                sorting_result = recording['results']['sorting-' + sorter['name']]
                comparison_result = recording['results']['comparison-with-truth-' + sorter['name']]
                units_info_result = recording['results']['units-info-' + sorter['name']]
                console_out_str = _console_out_to_str(sorting_result.runtime_info['console_out'])
                console_out_path = ka.store_text(console_out_str)
                sr = dict(
                    recording=recording,
                    sorter=sorter,
                    firings_true=recording['directory'] + '/firings_true.mda',
                    processor_name=sorter['processor_name'],
                    processor_version=sorting_result.version,
                    sorting_parameters=sorter['params'],
                    execution_stats=dict(
                        start_time=sorting_result.runtime_info['start_time'],
                        end_time=sorting_result.runtime_info['end_time'],
                        elapsed_sec=sorting_result.runtime_info['end_time'] - sorting_result.runtime_info['start_time'],
                        reported_elapsed_sec=_parse_spikeforest_runtime(console_out_str),
                        retcode=0 if sorting_result.success else -1,
                        timed_out=sorting_result.runtime_info.get('timed_out', False)
                    ),
                    container=sorting_result.container,
                    console_out=console_out_path
                )
                if sorting_result.success:
                    sr['firings'] = ka.store_file(sorting_result.outputs.sorting_out._path)
                    sr['comparison_with_truth'] = dict(
                        json=ka.store_file(comparison_result.outputs.json_out._path)
                    )
                    sr['sorted_units_info'] = ka.store_file(units_info_result.outputs.json_out._path)
                else:
                    sr['firings'] = None
                    sr['comparison_with_truth'] = None
                    sr['sorted_units_info'] = None
                sorting_results.append(sr)
    
    # Delete results from recordings
    for recording in recordings:
        del recording['results']

    # Aggregate sorting results
    print('')
    print('=======================================================')
    print('Aggregating sorting results...')
    aggregated_sorting_results = aggregate_sorting_results(studies, recordings, sorting_results)

    # Show output summary
    for sr in aggregated_sorting_results['study_sorting_results']:
        study_name = sr['study']
        sorter_name = sr['sorter']
        n1 = np.array(sr['num_matches'])
        n2 = np.array(sr['num_false_positives'])
        n3 = np.array(sr['num_false_negatives'])
        accuracies = n1 / (n1 + n2 + n3)
        avg_accuracy = np.mean(accuracies)
        txt = 'STUDY: {}, SORTER: {}, AVG ACCURACY: {}'.format(study_name, sorter_name, avg_accuracy)
        print(txt)
    
    output_object = dict(
        studies=studies,
        recordings=recordings,
        study_sets=study_sets,
        sorting_results=sorting_results,
        aggregated_sorting_results=ka.store_object(aggregated_sorting_results, basename='aggregated_sorting_results.json')
    )

    print(f'Writing output to {args.output}...')
    with open(args.output, 'w') as f:
        json.dump(output_object, f, indent=4)
    print('Done.')
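When --slurm is given, this script expects a JSON file with 'cpu', 'gpu', and 'ks' sections, each a dict of keyword arguments forwarded to hither.SlurmJobHandler (see the handler construction above). A skeleton, with the per-section options left empty because they depend on the hither version and cluster in use:

# Illustrative skeleton of the --slurm file, shown as the Python dict it decodes to.
slurm_config = {
    'cpu': {},  # SlurmJobHandler options for CPU-only sorters
    'gpu': {},  # SlurmJobHandler options for GPU sorters (e.g. tridesclous)
    'ks': {},   # SlurmJobHandler options for KiloSort/KiloSort2
}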
Example #14
def main():
    from mountaintools import client as mt

    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument('analysis_path',
                        help='assembled analysis file (output.json)')
    parser.add_argument(
        '--studysets',
        help='Comma-separated list of study set names to include',
        required=False,
        default=None)
    parser.add_argument('--force-run',
                        help='Force rerunning of processing',
                        action='store_true')
    parser.add_argument(
        '--force-run-all',
        help='Force rerunning of processing including filtering',
        action='store_true')
    parser.add_argument('--parallel',
                        help='Optional number of parallel jobs',
                        required=False,
                        default='0')
    parser.add_argument('--slurm',
                        help='Path to slurm config file',
                        required=False,
                        default=None)
    parser.add_argument('--cache',
                        help='The cache database to use',
                        required=False,
                        default=None)
    parser.add_argument('--job-timeout',
                        help='Timeout for processing jobs',
                        required=False,
                        default=600)
    parser.add_argument('--log-file',
                        help='Log file for analysis progress',
                        required=False,
                        default=None)
    parser.add_argument(
        '--force-regenerate',
        help=
        'Whether to force regenerating spike sprays (for when code has changed)',
        action='store_true')
    parser.add_argument('--test',
                        help='Whether to just test by running only 1',
                        action='store_true')

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    with open(args.analysis_path, 'r') as f:
        analysis = json.load(f)

    if args.studysets is not None:
        studyset_names = args.studysets.split(',')
        print('Using study sets: ', studyset_names)
    else:
        studyset_names = None

    study_sets = analysis['StudySets']
    sorting_results = analysis['SortingResults']

    studies_to_include = []
    for ss in study_sets:
        if (studyset_names is None) or (ss['name'] in studyset_names):
            for study in ss['studies']:
                studies_to_include.append(study['name'])

    print('Including studies:', studies_to_include)

    print('Determining sorting results to process ({} total)...'.format(
        len(sorting_results)))
    sorting_results_to_process = []
    sorting_results_to_consider = []
    for sr in sorting_results:
        study_name = sr['studyName']
        if study_name in studies_to_include:
            if 'firings' in sr:
                if sr.get('comparisonWithTruth', None) is not None:
                    sorting_results_to_consider.append(sr)
                    key = dict(name='unit-details-v0.1.0',
                               recording_directory=sr['recordingDirectory'],
                               firings_true=sr['firingsTrue'],
                               firings=sr['firings'])
                    val = mt.getValue(key=key, collection='spikeforest')
                    if (not val) or (args.force_regenerate):
                        sr['key'] = key
                        sorting_results_to_process.append(sr)
    if args.test and len(sorting_results_to_process) > 0:
        sorting_results_to_process = [sorting_results_to_process[0]]

    print('Need to process {} of {} sorting results'.format(
        len(sorting_results_to_process), len(sorting_results_to_consider)))

    recording_directories_to_process = sorted(
        list(
            set([
                sr['recordingDirectory'] for sr in sorting_results_to_process
            ])))
    print('{} recording directories to process'.format(
        len(recording_directories_to_process)))

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                             **slurm_config['cpu'])
    else:
        job_handler = None

    print('Filtering recordings...')
    filter_results = []
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=float(args.job_timeout)), hither.job_queue():
        for recdir in recording_directories_to_process:
            result = filter_recording.run(recording_directory=recdir,
                                          timeseries_out=hither.File())
            filter_results.append(result)
    filtered_timeseries_by_recdir = dict()
    for i, recdir in enumerate(recording_directories_to_process):
        result0 = filter_results[i]
        if not result0.success:
            raise Exception(
                'Problem computing filtered timeseries for recording: {}'.
                format(recdir))
        filtered_timeseries_by_recdir[
            recdir] = result0.outputs.timeseries_out._path

    print('Creating spike sprays...')
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run or args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=float(args.job_timeout)), hither.job_queue():
        for sr in sorting_results_to_process:
            recdir = sr['recordingDirectory']
            study_name = sr['studyName']
            rec_name = sr['recordingName']
            sorter_name = sr['sorterName']

            print('====== COMPUTING {}/{}/{}'.format(study_name, rec_name,
                                                     sorter_name))

            cwt = ka.load_object(path=sr['comparisonWithTruth']['json'])

            filtered_timeseries = filtered_timeseries_by_recdir[recdir]

            spike_spray_results = []
            list0 = list(cwt.values())
            for _, unit in enumerate(list0):
                result = create_spike_sprays.run(
                    recording_directory=recdir,
                    filtered_timeseries=filtered_timeseries,
                    firings_true=os.path.join(recdir, 'firings_true.mda'),
                    firings_sorted=sr['firings'],
                    unit_id_true=unit['unit_id'],
                    unit_id_sorted=unit['best_unit'],
                    json_out=hither.File())
                setattr(result, 'unit', unit)
                spike_spray_results.append(result)
            sr['spike_spray_results'] = spike_spray_results

    for sr in sorting_results_to_process:
        recdir = sr['recordingDirectory']
        study_name = sr['studyName']
        rec_name = sr['recordingName']
        sorter_name = sr['sorterName']

        print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                              sorter_name))
        spike_spray_results = sr['spike_spray_results']
        key = sr['key']

        unit_details = []
        ok = True
        for i, result in enumerate(spike_spray_results):
            if not result.success:
                print(
                    'WARNING: Error creating spike sprays for {}/{}/{}'.format(
                        study_name, rec_name, sorter_name))
                ok = False
                break
            ssobj = ka.load_object(result.outputs.json_out._path)
            if ssobj is None:
                raise Exception('Problem loading spikespray object output.')
            address = mt.saveObject(object=ssobj,
                                    upload_to='spikeforest.kbucket')
            unit = getattr(result, 'unit')
            unit_details.append(
                dict(studyName=study_name,
                     recordingName=rec_name,
                     sorterName=sorter_name,
                     trueUnitId=unit['unit_id'],
                     sortedUnitId=unit['best_unit'],
                     spikeSprayUrl=mt.findFile(
                         path=address,
                         remote_only=True,
                         download_from='spikeforest.kbucket')))

        if ok:
            mt.saveObject(collection='spikeforest',
                          key=key,
                          object=unit_details,
                          upload_to='spikeforest.public')
Example #15
def main():
    example3()

    with hither.config(container='docker://python:3.7'):
        example3()
Example #16
#!/usr/bin/env python

from spikeforest2 import sorters
import hither
import kachery as ka

recording_path = 'sha1://961f4a641af64dded4821610189f808f0192de4d/SYNTH_MEAREC_TETRODE/synth_mearec_tetrode_noise10_K10_C4/002_synth.json'

with ka.config(fr='default_readonly'):
    with hither.config(container='default', gpu=False, job_timeout=5, exception_on_fail=False):
        result = sorters.mountainsort4.run(
            recording_path=recording_path,
            sorting_out=hither.File()
        )

print('Status: ', result.status)
print('Success: ', result.success)
print('Timed out:', result.runtime_info['timed_out'])

assert result.status == 'error'
assert result.success is False
assert result.runtime_info['timed_out'] is True

print('Passed.')
Example #17
def example1():
    with hither.config(container='default'):
        result = hello_hither_scipy.run(n=20)
        print(result.retval)
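Examples #6 and #17 assume a hither function named hello_hither_scipy whose cost grows with n. A hypothetical stand-in, assuming the @hither.function(name, version) registration style (the body is illustrative, not the actual example function):

import hither

@hither.function('hello_hither_scipy', '0.1.0')
def hello_hither_scipy(n):
    # Illustrative workload: largest eigenvalue magnitude of a random n x n matrix.
    import numpy as np
    from scipy import linalg
    A = np.random.rand(n, n)
    return float(np.max(np.abs(linalg.eigvals(A))))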