def main():

    parser = argparse.ArgumentParser(
        description='update spike-front analysis history')
    parser.add_argument(
        '--delete-and-init',
        help='Use this only the first time. Will wipe out the history.',
        action='store_true')
    parser.add_argument('--show',
                        help='Show the analysis history.',
                        action='store_true')

    args = parser.parse_args()

    history_path = 'key://pairio/spikeforest/spike-front-analysis-history.json'
    results_path = 'key://pairio/spikeforest/spike-front-results.json'

    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])
    if args.delete_and_init:
        print('Initializing analysis history...')
        mt.createSnapshot(mt.saveObject(object=dict(analyses=[])),
                          dest_path=history_path,
                          upload_to='spikeforest.public')
        print('Done.')
        return

    print('Loading history...')
    history = mt.loadObject(path=history_path)
    assert history, 'Unable to load history'

    if args.show:
        for aa in history['analyses']:
            print(
                '==============================================================='
            )
            print('ANALYSIS: {}'.format(aa['General']['dateUpdated']))
            print('PATH: {}'.format(aa['path']))
            print(json.dumps(aa['General'], indent=4))
            print('')
        return

    print('Loading current results...')
    spike_front_results = mt.loadObject(path=results_path)
    assert spike_front_results, 'Unable to load current results'

    sha1_url = mt.saveObject(object=spike_front_results,
                             basename='analysis.json')
    for aa in history['analyses']:
        if aa['path'] == sha1_url:
            print('Analysis already stored in history.')
            return

    history['analyses'].append(
        dict(path=sha1_url, General=spike_front_results['General'][0]))

    print('Saving updated history to {}'.format(history_path))
    mt.createSnapshot(mt.saveObject(object=history),
                      dest_path=history_path,
                      upload_to='spikeforest.public')
    print('Done.')
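All of these snippets rely on the MountainTools client; a minimal sketch of the save/load round trip they build on is below (the import line is an assumption based on the SpikeForest codebase, and the object contents are illustrative):

# Sketch: mt.saveObject returns a content address (cf. Example #29 below),
# which mt.loadObject can resolve back into the object.
from mountaintools import client as mt  # assumed import convention

address = mt.saveObject(object=dict(analyses=[]))
assert address is not None
obj = mt.loadObject(path=address)
print(obj)  # -> {'analyses': []}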
Example #2
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running NWBFile')
        mt.configDownloadFrom(state.get('download_from', []))
        path = state.get('path', None)
        if path:
            if path.endswith('.nwb'):
                self._set_status('running',
                                 'Reading nwb file: {}'.format(path))
                obj = nwb_to_dict(path,
                                  use_cache=False,
                                  exclude_data=True,
                                  verbose=False)
                self._set_status('running',
                                 'Finished nwb file: {}'.format(path))
            else:
                self._set_status('running',
                                 'Realizing object: {}'.format(path))
                obj = mt.loadObject(path=path)
            if not obj:
                self._set_error('Unable to realize object: {}'.format(path))
                return
            self.set_python_state(
                dict(status='finished',
                     status_message='finished loading: {}'.format(path),
                     object=obj))
        else:
            self._set_error('Missing path')
Example #3
def _load_analysis_context(path):
    obj = mt.loadObject(path=path)
    if not obj:
        print('Unable to load file: ' + path, file=sys.stderr)
        return None
    context = AnalysisContext(obj=obj)
    return context
Example #4
def do_sorting_test(sorting_processor,
                    params,
                    recording_dir,
                    assert_avg_accuracy,
                    container='default'):
    mt.configDownloadFrom('spikeforest.kbucket')

    recdir = recording_dir
    mt.createSnapshot(path=recdir, download_recursive=True)
    sorting = sorting_processor.execute(recording_dir=recdir,
                                        firings_out={'ext': '.mda'},
                                        **params,
                                        _container=container,
                                        _force_run=True)

    comparison = sa.GenSortingComparisonTable.execute(
        firings=sorting.outputs['firings_out'],
        firings_true=recdir + '/firings_true.mda',
        units_true=[],
        json_out={'ext': '.json'},
        html_out={'ext': '.html'},
        _container=None,
        _force_run=True)

    X = mt.loadObject(path=comparison.outputs['json_out'])
    accuracies = [float(a['accuracy']) for a in X.values()]
    avg_accuracy = np.mean(accuracies)

    print('Average accuracy: {}'.format(avg_accuracy))

    assert avg_accuracy >= assert_avg_accuracy
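A hypothetical invocation of do_sorting_test follows; the processor class, parameters, and recording path are placeholders rather than values from the source:

# Placeholder usage sketch -- substitute a real mlprocessors sorting
# processor and a real recording directory containing firings_true.mda.
do_sorting_test(
    sorting_processor=MySortingProcessor,  # hypothetical processor class
    params=dict(detect_sign=-1),           # illustrative sorter parameters
    recording_dir='sha1dir://<hash>/<study>/<recording>',  # placeholder
    assert_avg_accuracy=0.8)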
Example #5
    def compute_score(self, sorting_extractor):
        if self.metric != 'spikeforest':
            comparison = sc.compare_sorter_to_ground_truth(
                self.gt_se, sorting_extractor, exhaustive_gt=True)
            d_results = comparison.get_performance(
                method='pooled_with_average', output='dict')
            print('results')
            print(d_results)
            if self.metric == 'accuracy':
                score = d_results['accuracy']
            elif self.metric == 'precision':
                score = d_results['precision']
            elif self.metric == 'recall':
                score = d_results['recall']
            elif self.metric == 'f1':
                print('comparison:')
                print(d_results)
                if (d_results['precision'] + d_results['recall']) > 0:
                    score = (2 * d_results['precision'] * d_results['recall']
                             / (d_results['precision'] + d_results['recall']))
                else:
                    score = 0
            else:
                raise ValueError('Unrecognized metric: {}'.format(self.metric))
            del comparison
        else:
            tmp_dir = 'test_outputs_spikeforest'
            SFMdaSortingExtractor.write_sorting(
                sorting=sorting_extractor,
                save_path=os.path.join(tmp_dir, 'firings.mda'))
            print('Compare with ground truth...')
            sa.GenSortingComparisonTable.execute(
                firings=os.path.join(tmp_dir, 'firings.mda'),
                firings_true=os.path.join(tmp_dir, 'firings_true.mda'),
                units_true=self.true_units_above,  # use all units
                json_out=os.path.join(tmp_dir, 'comparison.json'),
                html_out=os.path.join(tmp_dir, 'comparison.html'),
                _container=None)
            comparison = mt.loadObject(
                path=os.path.join(tmp_dir, 'comparison.json'))
            score = np.mean([float(u['accuracy']) for u in comparison.values()])
        return -score  # negated so the caller can minimize
Example #6
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running')

        sorting0 = state.get('sorting', None)
        if not sorting0:
            self._set_error('Missing: sorting')
            return
        try:
            self._sorting = AutoSortingExtractor(sorting0)
        except Exception as err:
            traceback.print_exc()
            self._set_error('Problem initiating sorting: {}'.format(err))
            return

        max_samples = state.get('max_samples')
        max_dt_msec = state.get('max_dt_msec')
        bin_size_msec = state.get('bin_size_msec')
        if not max_dt_msec:
            # state not fully populated yet; wait for the next update
            return

        result = ComputeAutocorrelograms.execute(sorting=self._sorting,
                                                 max_samples=max_samples,
                                                 bin_size_msec=bin_size_msec,
                                                 max_dt_msec=max_dt_msec,
                                                 json_out=dict(ext='.json'))
        if result.retcode != 0:
            self._set_error('Error computing autocorrelograms.')
            return

        output = mt.loadObject(path=result.outputs['json_out'])
        self._set_state(status='finished', output=output)
Example #7
def assembleBatchResults(*, batch_name):
    batch = mt.loadObject(key=dict(batch_name=batch_name))
    jobs = batch['jobs']

    print('Assembling results for batch {} with {} jobs'.format(
        batch_name, len(jobs)))
    job_results = []
    for job in jobs:
        print('ASSEMBLING: ' + job['label'])
        result = mt.loadObject(key=job)
        if not result:
            raise Exception('Unable to load object for job: ' + job['label'])
        job_results.append(dict(job=job, result=result))
    print('Saving results...')
    mt.saveObject(key=dict(name='job_results', batch_name=batch_name),
                  object=dict(job_results=job_results))
    print('Done.')
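These batch helpers address objects by key rather than by content path; a minimal sketch of that key-based round trip (the key and object contents are illustrative):

# Sketch: save under a key, then load by the same key, mirroring the
# batch_name lookups above.
from mountaintools import client as mt  # assumed import convention

key = dict(batch_name='my_test_batch')
mt.saveObject(key=key, object=dict(jobs=[]))
batch = mt.loadObject(key=key)
print(batch)  # -> {'jobs': []}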
Example #8
    def __init__(self, *, context, opts=None):
        vd.Component.__init__(self)
        self._context = context
        self._size = (100, 100)
        if not context.comparisonWithTruthPath():
            self._object = None
        else:
            self._object = mt.loadObject(path=context.comparisonWithTruthPath())
Example #9
def main():
    parser = argparse.ArgumentParser(
        description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--output_dir',
                        help='The output directory for saving the files.')
    parser.add_argument('--download-from',
                        help='The remote location(s) to download the data from.',
                        required=False,
                        default='spikeforest.kbucket')
    parser.add_argument('--key_path', help='Key path to retrieve data from')

    args = parser.parse_args()

    output_dir = args.output_dir

    if os.path.exists(output_dir):
        raise Exception(
            'Output directory already exists: {}'.format(output_dir))

    os.mkdir(output_dir)

    mt.configDownloadFrom(args.download_from)

    print('Loading spike-front results object...')
    obj = mt.loadObject(path=args.key_path)

    StudySets = obj['StudySets']
    SortingResults = obj['SortingResults']
    Sorters = obj['Sorters']
    Algorithms = obj['Algorithms']
    StudyAnalysisResults = obj['StudyAnalysisResults']
    General = obj['General']

    print('Saving {} study sets to {}/StudySets.json'.format(
        len(StudySets), output_dir))
    mt.saveObject(object=StudySets, dest_path=output_dir + '/StudySets.json')

    print('Saving {} sorting results to {}/SortingResults.json'.format(
        len(SortingResults), output_dir))
    mt.saveObject(object=SortingResults,
                  dest_path=output_dir + '/SortingResults.json')

    print('Saving {} sorters to {}/Sorters.json'.format(
        len(Sorters), output_dir))
    mt.saveObject(object=Sorters, dest_path=output_dir + '/Sorters.json')

    print('Saving {} algorithms to {}/Algorithms.json'.format(
        len(Algorithms), output_dir))
    mt.saveObject(object=Algorithms, dest_path=output_dir + '/Algorithms.json')

    print('Saving {} study analysis results to {}/StudyAnalysisResults.json'.
          format(len(StudyAnalysisResults), output_dir))
    mt.saveObject(object=StudyAnalysisResults,
                  dest_path=output_dir + '/StudyAnalysisResults.json')

    print('Saving general info to {}/General.json'.format(output_dir))
    mt.saveObject(object=General, dest_path=output_dir + '/General.json')
Example #10
def main():
    parser = argparse.ArgumentParser(description='Show the publicly downloadable studies of SpikeForest')
    parser.add_argument('--group_names', help='Comma-separated list of recording group names.', required=False, default=None)

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.public'])

    if args.group_names is not None:
        group_names = args.group_names.split(',')
    else:
        group_names = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using group names: ', group_names)

    studies = []
    study_sets = []
    recordings = []
    for group_name in group_names:
        print('RECORDING GROUP: {}'.format(group_name))
        output_path = ('key://pairio/spikeforest/spikeforest_recording_group.{}.json').format(group_name)
        obj = mt.loadObject(path=output_path)
        if obj:
            studies = studies + obj['studies']
            study_sets = study_sets + obj.get('study_sets', [])
            recordings = recordings + obj['recordings']
            study_sets_by_study = dict()
            for study in obj['studies']:
                study_sets_by_study[study['name']] = study['study_set']
            for rec in obj['recordings']:
                if rec.get('public', False):
                    study_set = study_sets_by_study.get(rec.get('study', ''), '')
                    print('{}/{}/{}: {}'.format(study_set, rec.get('study', ''), rec.get('name', ''), rec.get('directory', '')))
        else:
            print('WARNING: unable to load object: ' + output_path)

    print('')
    print('ALL GROUPS')
    study_sets_by_study = dict()
    for study in studies:
        study_sets_by_study[study['name']] = study['study_set']
    for rec in recordings:
        if rec.get('public', False):
            study_set = study_sets_by_study.get(rec.get('study', ''), '')
            print('- {}/{}/{}: `{}`'.format(study_set, rec.get('study', ''), rec.get('name', ''), rec.get('directory', '')))
Example #11
def clearBatch(*, batch_name, test_one=False):
    batch = mt.loadObject(key=dict(batch_name=batch_name))
    jobs = batch['jobs']

    if test_one and (len(jobs) > 0):
        jobs = [jobs[0]]

    setBatchStatus(batch_name=batch_name, status='clearing_batch')
    _clear_job_results(jobs=jobs, incomplete_only=False)
    setBatchStatus(batch_name=batch_name, status='finished_clearing_batch')
Example #12
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))

        max_samples = state.get('max_samples')
        max_dt_msec = state.get('max_dt_msec')
        bin_size_msec = state.get('bin_size_msec')
        if not max_dt_msec:
            # state not fully populated yet; wait for the next update
            return

        firings_path = state.get('firingsPath', None)
        if not firings_path:
            self.set_python_state(dict(
                status='error',
                status_message='No firingsPath provided'
            ))
            return

        samplerate = state.get('samplerate', None)
        if not samplerate:
            self.set_python_state(dict(
                status='error',
                status_message='No samplerate provided'
            ))
            return

        self.set_python_state(dict(status_message='Realizing file: {}'.format(firings_path)))
        firings_path2 = mt.realizeFile(firings_path)
        if not firings_path2:
            self.set_python_state(dict(
                status='error',
                status_message='Unable to realize file: {}'.format(firings_path)
            ))
            return

        result = ComputeAutocorrelograms.execute(
            firings_path=firings_path2,
            samplerate=samplerate,
            max_samples=max_samples,
            bin_size_msec=bin_size_msec,
            max_dt_msec=max_dt_msec,
            json_out=dict(ext='.json')
        )
        if result.retcode != 0:
            self.set_python_state(dict(
                status='error',
                status_message='Error computing autocorrelogram.'
            ))
            return

        output = mt.loadObject(path=result.outputs['json_out'])
        self.set_python_state(dict(
            status='finished',
            output=output
        ))
Example #13
def runBatch(*, batch_name, test_one=False):
    print('Loading batch object...')
    batch = mt.loadObject(key=dict(batch_name=batch_name))
    jobs = batch['jobs']

    if test_one and (len(jobs) > 0):
        jobs = [jobs[0]]

    print('Running batch with {} jobs...'.format(len(jobs)))
    for job in jobs:
        _run_job(job)
Example #14
    def _open_job(self, job_index):
        job0 = self._jobs[job_index]
        if 'result' not in job0:
            job_result_key = dict(name='compute_resource_batch_job_results',
                                  batch_id=self._batch_id)
            result0 = mt.loadObject(key=job_result_key, subkey=str(job_index))
            if result0:
                job0['result'] = result0
        self._job_view.setJob(job0)
        self._list_mode = False
        self.refresh()
Example #15
def main():
    path = 'sha1dir://8516cc54587e0c5ddd0709154e7f609b9b7884b4'
    mt.configDownloadFrom('spikeforest.public')
    X = mt.readDir(path)
    for study_set_name, d in X['dirs'].items():
        for study_name, d2 in d['dirs'].items():
            for recording_name, d3 in d2['dirs'].items():
                x = mt.loadObject(path=path + '/' + study_set_name + '/' +
                                  study_name + '/' + recording_name +
                                  '.runtime_info.json')
                print('{}/{}/{}\n{} sec\n'.format(study_set_name, study_name,
                                                  recording_name,
                                                  x['elapsed_sec']))
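For reference, the nested shape that mt.readDir appears to return, inferred from the traversal above (only the 'dirs' key is confirmed by this example):

# Inferred (not authoritative) shape of X = mt.readDir(path):
# X = {
#     'dirs': {
#         '<study_set_name>': {
#             'dirs': {
#                 '<study_name>': {
#                     'dirs': {'<recording_name>': {...}}
#                 }
#             }
#         }
#     }
# }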
Example #16
def h5_to_dict(fname, *, upload_to=None, use_cache=False):
    if use_cache:
        result = H5ToDict.execute(h5_in=fname,
                                  upload_to=upload_to or '',
                                  json_out={'ext': '.json'})
        if result.retcode != 0:
            raise Exception('Problem running H5ToDict.')
        return mt.loadObject(path=result.outputs['json_out'])

    fname = mt.realizeFile(path=fname)
    opts = dict(upload_to=upload_to)
    with h5py.File(fname, 'r') as f:
        opts['file'] = f
        return _h5_to_dict(f, opts=opts, name=None)
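A hypothetical call to h5_to_dict; the file path is a placeholder:

# Placeholder usage: convert a local HDF5 file to a nested dict directly,
# bypassing the cached H5ToDict processor branch.
d = h5_to_dict('/path/to/example.h5', use_cache=False)
print(list(d.keys()))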
Example #17
def prepareBatch(*, batch_name, test_one=False):
    batch = mt.loadObject(key=dict(batch_name=batch_name))
    jobs = batch['jobs']

    if test_one and (len(jobs) > 0):
        jobs = [jobs[0]]

    setBatchStatus(batch_name=batch_name, status='preparing_batch')
    _clear_job_results(jobs=jobs, incomplete_only=True)

    setBatchStatus(batch_name=batch_name, status='downloading_recordings')
    _download_recordings(jobs=jobs)

    setBatchStatus(batch_name=batch_name, status='finished_preparing_batch')
Example #18
    def _on_group_changed(self):
        output_id = self._SEL_group.value()
        if not output_id:
            return
        a = mt.loadObject(key=dict(name='spikeforest_results'),
                          subkey=output_id)
        SF = sf.SFData()
        SF.loadStudies(a['studies'])
        SF.loadRecordings2(a['recordings'])
        self._SF = SF
        self._SEL_study.setOptions(SF.studyNames())
        self._on_study_changed(value=self._SEL_study.value())
        self.refresh()
Example #19
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))

        # get javascript state
        download_from = state.get('download_from', [])
        path = state.get('path', None)
        name = state.get('name', None)

        if path and name:
            mt.configDownloadFrom(download_from)
            if path.endswith('.nwb'):
                self.set_python_state(
                    dict(status_message='Realizing object from nwb file: {}'.
                         format(path)))
                obj = nwb_to_dict(path, use_cache=True)
            else:
                self.set_python_state(
                    dict(status_message='Realizing object: {}'.format(path)))
                obj = mt.loadObject(path=path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             path)))
                return
            datasets = obj['general']['subject']['cortical_surfaces'][name][
                '_datasets']
            faces0 = np.load(mt.realizeFile(datasets['faces']['_data']))
            vertices = np.load(mt.realizeFile(datasets['vertices']['_data'])).T

            # there's a better way to do the following
            # (need to get it into a single vector format)
            faces = []
            for j in range(faces0.shape[0]):
                # 3 = #vertices in polygon (assume a triangulation)
                faces.extend([3, faces0[j, 0], faces0[j, 1], faces0[j, 2]])
            faces = np.array(faces)

            # return this python state to the javascript
            self.set_python_state(
                dict(faces=faces,
                     vertices=vertices,
                     status='finished',
                     status_message='Done.'))
        else:
            self.set_python_state(
                dict(status='error',
                     status_message='Missing path and/or name'))
Example #20
def _load_spikefront_context(path):
    obj = mt.loadObject(path=path)
    if not obj:
        print('Unable to load file: ' + path, file=sys.stderr)
        return None
    context = SpikeFrontContext(StudySets=obj.get("StudySets", []),
                                Recordings=obj.get("Recordings", []),
                                TrueUnits=obj.get("TrueUnits", []),
                                UnitResults=obj.get("UnitResults", []),
                                SortingResults=obj.get("SortingResults", []),
                                Sorters=obj.get("Sorters", []),
                                Studies=obj.get("Studies", []),
                                Algorithms=obj.get("Algorithms", []),
                                StudyAnalysisResults=obj.get(
                                    "StudyAnalysisResults", []))
    return context
Example #21
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))
        path = state.get('path', None)
        if path:
            self.set_python_state(
                dict(status_message='Realizing object: {}'.format(path)))
            obj = mt.loadObject(path=path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             path)))
                return
            state['object'] = obj
            state['status'] = 'finished'
            self.set_python_state(state)
Example #22
    def prepareView(context, opts):
        sorting_context = context
        recording_context = context.recordingContext()
        try:
            recording_context.initialize()
            sorting_context.initialize()
            print('***** Preparing efficient access recording extractor...')
            earx = EfficientAccessRecordingExtractor(
                recording=recording_context.recordingExtractor())
            print('***** computing units info...')
            info0 = mt.loadObject(path=ComputeUnitsInfo.execute(
                recording=earx,
                sorting=sorting_context.sortingExtractor(),
                json_out=True).outputs['json_out'])
            print('*****')
        except Exception:
            traceback.print_exc()
            raise
        return dict(
            units_info=info0
        )
Example #23
def _load_spikeforest_context(path):
    if mt.isFile(path):
        obj = mt.loadObject(path=path)
        if not obj:
            print('Unable to load file: ' + path, file=sys.stderr)
            return None
    else:
        obj = _make_obj_from_dir(path)
        if not obj:
            print('Unable to make object from path: ' + path)
            return None
    context = SpikeForestContext(
        studies=obj.get('studies', []),
        recordings=obj.get('recordings', []),
        sorting_results=obj.get('sorting_results', []),
        aggregated_sorting_results=obj.get('aggregated_sorting_results', None))
    return context
Example #24
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running Analysis')

        path = state.get('path', None)
        if not path:
            self._set_error('Missing path')
            return

        self._set_status('running', 'Loading object: {}'.format(path))
        obj = mt.loadObject(path=path,
                            download_from=state.get('download_from', None))
        if not obj:
            self._set_error('Unable to load object: {}'.format(path))
            return

        self._set_state(object=obj)

        self._set_status('finished', 'Finished Analysis')
Example #25
    def __init__(self, output_id):
        vd.Component.__init__(self)

        self._output_id = output_id

        a = mt.loadObject(key=dict(name='spikeforest_results'),
                          subkey=output_id)
        if not a:
            print('ERROR: unable to open results: ' + output_id)
            return

        if ('recordings' not in a) or ('studies'
                                       not in a) or ('sorting_results'
                                                     not in a):
            print('ERROR: problem with output: ' + output_id)
            return

        studies = a['studies']
        recordings = a['recordings']
        sorting_results = a['sorting_results']

        SF = sf.SFData()
        SF.loadStudies(studies)
        SF.loadRecordings2(recordings)
        SF.loadSortingResults(sorting_results)

        # sorter_names=[]
        # for SR in sorting_results:
        #     sorter_names.append(SR['sorter']['name'])
        # sorter_names=list(set(sorter_names))
        # sorter_names.sort()

        self._SF_data = SF

        self._accuracy_threshold_input = vd.components.LineEdit(
            value=0.8, dtype=float, style=dict(width='70px'))
        self._update_button = vd.components.Button(onclick=self._on_update,
                                                   class_='button',
                                                   label='Update')
        self._study_sorter_fig = StudySorterFigure(SF)
        self._study_sorter_table = vd.div()  # dummy

        vd.devel.loadBootstrap()

        self._update_accuracy_table()
Example #26
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))
        nwb_path = state.get('nwb_path', None)
        downsample_factor = state.get('downsample_factor', 1)
        if nwb_path:
            if nwb_path.endswith('.nwb'):
                self.set_python_state(
                    dict(status_message='Realizing object from nwb file: {}'.
                         format(nwb_path)))
                obj = h5_to_dict(nwb_path, use_cache=True)
            else:
                self.set_python_state(
                    dict(status_message='Realizing object: {}'.format(
                        nwb_path)))
                obj = mt.loadObject(path=nwb_path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             nwb_path)))
                return
            try:
                positions_path = obj['processing']['Behavior']['Position'][
                    'Position']['_datasets']['data']['_data']
            except Exception:
                self.set_python_state(
                    dict(status='error',
                         status_message=
                         'Problem extracting behavior positions in file: {}'.
                         format(nwb_path)))
                return
            positions = np.load(mt.realizeFile(path=positions_path))
            positions = positions[::downsample_factor, :]

            self.set_python_state(
                dict(status_message='Finished loading positions'))
            state['positions'] = positions
            state['status'] = 'finished'
            self.set_python_state(state)
        else:
            self.set_python_state(
                dict(status='error',
                     status_message='Missing in state: nwb_path'))
Example #27
def install_waveclus(repo, commit):
    spikeforest_alg_install_path = get_install_path()
    key = dict(
        alg='waveclus',
        repo=repo,
        commit=commit
    )
    source_path = spikeforest_alg_install_path + '/waveclus_' + commit
    if os.path.exists(source_path):
        # The dir hash method does not seem to be working for some reason here
        # hash0 = mt.computeDirHash(source_path)
        # if hash0 == mt.getValue(key=key):
        #     print('waveclus is already auto-installed.')
        #     return source_path

        a = mt.loadObject(path=source_path + '/spikeforest.json')
        if a:
            if mt.sha1OfObject(a) == mt.sha1OfObject(key):
                print('waveclus is already auto-installed.')
                return source_path

        print('Removing directory: {}'.format(source_path))
        shutil.rmtree(source_path)

    script = """
    #!/bin/bash
    set -e

    git clone {repo} {source_path}
    cd {source_path}
    git checkout {commit}
    """.format(repo=repo, commit=commit, source_path=source_path)
    ss = mlpr.ShellScript(script=script)
    ss.start()
    retcode = ss.wait()
    if retcode != 0:
        raise Exception('Install script returned a non-zero exit code.')

    # The dir hash method does not seem to be working for some reason here
    # hash0 = mt.computeDirHash(source_path)
    # mt.setValue(key=key, value=hash0)
    mt.saveObject(object=key, dest_path=source_path + '/spikeforest.json')

    return source_path
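A hypothetical call to install_waveclus; the repository URL and commit are placeholders:

# Placeholder usage: clones the repo at the given commit under the install
# path and records the install key in spikeforest.json.
source_path = install_waveclus(
    repo='https://github.com/<org>/wave_clus',
    commit='<commit-sha>')
print('waveclus installed at: ' + source_path)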
Example #28
    def _on_output_id_changed(self, value):
        output_id = self._SEL_output_id.value()
        if not output_id:
            return
        key = dict(
            name='spikeforest_results',
            output_id=output_id
        )
        a = mt.loadObject(key=key)
        if a is None:
            raise Exception(
                'Unable to load spikeforest result: {}'.format(output_id))
        SF = sf.SFData()
        SF.loadStudies(a['studies'])
        SF.loadRecordings2(a['recordings'])
        SF.loadSortingResults(a['sorting_results'])
        self._SF = SF
        self._SEL_study.setOptions(SF.studyNames())
        self._on_study_changed(value=self._SEL_study.value())
Example #29
def do_prepare(recording_group, study_name):
    print(recording_group, study_name)
    X = mt.loadObject(
        path="key://pairio/spikeforest/spikeforest_recording_group.{}.json".
        format(recording_group))
    studies = [y for y in X['studies'] if (y['name'] == study_name)]
    recordings = [y for y in X['recordings'] if y['study'] == study_name]
    recordings = recordings[0:1]  # keep only the first recording for this test
    study_sets = X['study_sets']

    Y = dict(studies=studies, recordings=recordings, study_sets=study_sets)
    address = mt.saveObject(object=Y)
    assert address is not None
    dest_path = 'key://pairio/spikeforest/spikeforest_recording_group.test_{}.json'.format(
        recording_group)
    print(dest_path)
    mt.createSnapshot(path=address,
                      upload_to='spikeforest.kbucket',
                      dest_path=dest_path)
Example #30
def do_sorting_test(
        sorting_processor,
        params,
        recording_dir,
        container='default',
        force_run=True,
        _keep_temp_files=False
    ):
    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])

    recdir = recording_dir
    mt.createSnapshot(path=recdir, download_recursive=True)
    timer = time.time()
    sorting = sorting_processor.execute(
        recording_dir=recdir,
        firings_out={'ext': '.mda'},
        **params,
        _container=container,
        _force_run=force_run,
        _keep_temp_files=_keep_temp_files
    )
    elapsed = time.time() - timer
    print('################ ELAPSED for sorting (sec): {}'.format(elapsed))

    timer = time.time()
    comparison = sa.GenSortingComparisonTable.execute(
        firings=sorting.outputs['firings_out'],
        firings_true=recdir + '/firings_true.mda',
        units_true=[],
        json_out={'ext': '.json'},
        html_out={'ext': '.html'},
        _container='default',
        _force_run=True
    )
    elapsed = time.time() - timer
    print('################ ELAPSED for comparison (sec): {}'.format(elapsed))

    X = mt.loadObject(path=comparison.outputs['json_out'])
    accuracies = [float(a['accuracy']) for a in X.values()]
    avg_accuracy = np.mean(accuracies)

    print('Average accuracy: {}'.format(avg_accuracy))