Example #1
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running TwoPhotonSeries')

        nwb_path = state.get('nwb_path', None)
        download_from = state.get('download_from', [])

        if not nwb_path:
            self._set_error('Missing nwb_path')
            return

        mt.configDownloadFrom(download_from)

        nwb_path2 = mt.realizeFile(nwb_path)
        if not nwb_path2:
            self._set_error('Unable to realize nwb file: {}'.format(nwb_path))
            return
        
        self._set_status('running', 'Extracting .mp4 data')
        outputs = ExtractTwoPhotonSeriesMp4.execute(nwb_in=nwb_path2, mp4_out={'ext': '.mp4'}).outputs

        self._set_status('running', 'Reading .mp4 data')
        mp4_fname = mt.realizeFile(outputs['mp4_out'])
        with open(mp4_fname, 'rb') as f:
            video_data = f.read()

        self._set_status('running', 'Encoding .mp4 data')
        video_data_b64 = base64.b64encode(video_data).decode()
        video_url = 'data:video/mp4;base64,{}'.format(video_data_b64)

        self._set_status('running', 'Setting .mp4 data to python state')
        self.set_python_state(dict(
            video_url=video_url,
            status='finished',
            status_message=''
        ))
Example #2
def main():

    parser = argparse.ArgumentParser(
        description='update spike-front analysis history')
    parser.add_argument(
        '--delete-and-init',
        help='Use this only the first time. Will wipe out the history.',
        action='store_true')
    parser.add_argument('--show',
                        help='Show the analysis history.',
                        action='store_true')

    args = parser.parse_args()

    history_path = 'key://pairio/spikeforest/spike-front-analysis-history.json'
    results_path = 'key://pairio/spikeforest/spike-front-results.json'

    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])
    if args.delete_and_init:
        print('Initializing analysis history...')
        mt.createSnapshot(mt.saveObject(object=dict(analyses=[])),
                          dest_path=history_path,
                          upload_to='spikeforest.public')
        print('Done.')
        return

    print('Loading history...')
    history = mt.loadObject(path=history_path)
    assert history, 'Unable to load history'

    if args.show:
        for aa in history['analyses']:
            print(
                '==============================================================='
            )
            print('ANALYSIS: {}'.format(aa['General']['dateUpdated']))
            print('PATH: {}'.format(aa['path']))
            print(json.dumps(aa['General'], indent=4))
            print('')
        return

    print('Loading current results...')
    spike_front_results = mt.loadObject(path=results_path)
    assert spike_front_results, 'Unable to load current results'

    sha1_url = mt.saveObject(object=spike_front_results,
                             basename='analysis.json')
    for aa in history['analyses']:
        if aa['path'] == sha1_url:
            print('Analysis already stored in history.')
            return

    history['analyses'].append(
        dict(path=sha1_url, General=spike_front_results['General'][0]))

    print('Saving updated history to {}'.format(history_path))
    mt.createSnapshot(mt.saveObject(object=history),
                      dest_path=history_path,
                      upload_to='spikeforest.public')
    print('Done.')
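A hedged invocation sketch for the script above (the file name update_history.py is an assumption, not stated in the original):

# First run only: initialize (wipes any existing history), then append or inspect:
#   python update_history.py --delete-and-init
#   python update_history.py            # append the current results object to the history
#   python update_history.py --show     # print the stored analyses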
Example #3
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))
        path = state.get('path', None)
        if path:
            self.set_python_state(
                dict(status_message='Realizing file: {}'.format(path)))
            if path.endswith('.csv'):
                path2 = mt.realizeFile(path)
                if not path2:
                    self.set_python_state(
                        dict(
                            status='error',
                            status_message='Unable to realize file: {}'.format(
                                path)))
                    return
                self.set_python_state(dict(status_message='Loading locations'))
                x = np.genfromtxt(path2, delimiter=',')
                locations = x.T
                num_elec = x.shape[0]
                labels = ['{}'.format(a) for a in range(1, num_elec + 1)]
            else:
                raise Exception('Unexpected file type for {}'.format(path))
        else:
            locations = [[0, 0], [1, 0], [1, 1], [2, 1]]
            labels = ['1', '2', '3', '4']
        state = dict()
        state['locations'] = locations
        state['labels'] = labels
        state['status'] = 'finished'
        self.set_python_state(state)
Example #4
def do_sorting_test(sorting_processor,
                    params,
                    recording_dir,
                    assert_avg_accuracy,
                    container='default'):
    mt.configDownloadFrom('spikeforest.kbucket')

    recdir = recording_dir
    mt.createSnapshot(path=recdir, download_recursive=True)
    sorting = sorting_processor.execute(recording_dir=recdir,
                                        firings_out={'ext': '.mda'},
                                        **params,
                                        _container=container,
                                        _force_run=True)

    comparison = sa.GenSortingComparisonTable.execute(
        firings=sorting.outputs['firings_out'],
        firings_true=recdir + '/firings_true.mda',
        units_true=[],
        json_out={'ext': '.json'},
        html_out={'ext': '.html'},
        _container=None,
        _force_run=True)

    X = mt.loadObject(path=comparison.outputs['json_out'])
    accuracies = [float(a['accuracy']) for a in X.values()]
    avg_accuracy = np.mean(accuracies)

    print('Average accuracy: {}'.format(avg_accuracy))

    assert (avg_accuracy >= assert_avg_accuracy)
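A hedged usage sketch for do_sorting_test (not part of the original): MountainSort4 is imported as shown in a later example in this listing, the recording path is a public SpikeForest recording that also appears below, and the sorter parameters and accuracy threshold are assumptions.

from spikeforestsorters import MountainSort4

do_sorting_test(
    sorting_processor=MountainSort4,
    params=dict(detect_sign=-1, adjacency_radius=50),
    recording_dir='sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth',
    assert_avg_accuracy=0.6)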
Example #5
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running NWBFile')
        mt.configDownloadFrom(state.get('download_from', []))
        path = state.get('path', None)
        if path:
            if path.endswith('.nwb'):
                self._set_status('running',
                                 'Reading nwb file: {}'.format(path))
                obj = nwb_to_dict(path,
                                  use_cache=False,
                                  exclude_data=True,
                                  verbose=False)
                self._set_status('running',
                                 'Finished nwb file: {}'.format(path))
            else:
                self._set_status('running',
                                 'Realizing object: {}'.format(path))
                obj = mt.loadObject(path=path)
            if not obj:
                self._set_error('Unable to realize object: {}'.format(path))
                return
            self.set_python_state(
                dict(status='finished',
                     status_message='finished loading: {}'.format(path),
                     object=obj))
        else:
            self._set_error('Missing path')
Example #6
    def javascript_state_changed(self, prev_state, state):
        if not self._recording:
            recordingPath = state.get('recordingPath', None)
            if not recordingPath:
                return
            self.set_python_state(dict(status_message='Loading recording'))
            mt.configDownloadFrom(state.get('download_from', []))
            X = SFMdaRecordingExtractor(dataset_directory=recordingPath,
                                        download=True)
            self.set_python_state(
                dict(numChannels=X.get_num_channels(),
                     numTimepoints=X.get_num_frames(),
                     samplerate=X.get_sampling_frequency(),
                     status_message='Loaded recording.'))
            self._recording = X
        else:
            X = self._recording

        SR = state.get('segmentsRequested', {})
        for key in SR.keys():
            aa = SR[key]
            if not self.get_python_state(key, None):
                self.set_python_state(
                    dict(status_message='Loading segment {}'.format(key)))
                data0 = self._load_data(aa['ds'], aa['ss'])
                data0_base64 = _mda32_to_base64(data0)
                state0 = {}
                state0[key] = dict(data=data0_base64, ds=aa['ds'], ss=aa['ss'])
                self.set_python_state(state0)
                self.set_python_state(
                    dict(status_message='Loaded segment {}'.format(key)))
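The helper _mda32_to_base64 is not shown in this example; a minimal sketch of what it could look like, assuming each requested segment is a float32 numpy array that is base64-encoded for transport to the javascript side:

import base64
import numpy as np

def _mda32_to_base64(X):
    # Serialize a float32 array to raw bytes and base64-encode it for transport.
    return base64.b64encode(np.ascontiguousarray(X, dtype=np.float32).tobytes()).decode()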
Example #7
def main():
    parser = argparse.ArgumentParser(
        description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--output_dir',
                        help='The output directory for saving the files.')
    parser.add_argument('--download-from',
                        help='The kachery node to download the files from.',
                        required=False,
                        default='spikeforest.kbucket')
    parser.add_argument('--key_path', help='Key path to retrieve data from')

    args = parser.parse_args()

    output_dir = args.output_dir

    if os.path.exists(output_dir):
        raise Exception(
            'Output directory already exists: {}'.format(output_dir))

    os.mkdir(output_dir)

    mt.configDownloadFrom(args.download_from)

    print('Loading spike-front results object...')
    obj = mt.loadObject(path=args.key_path)

    StudySets = obj['StudySets']
    SortingResults = obj['SortingResults']
    Sorters = obj['Sorters']
    Algorithms = obj['Algorithms']
    StudyAnalysisResults = obj['StudyAnalysisResults']
    General = obj['General']

    print('Saving {} study sets to {}/StudySets.json'.format(
        len(StudySets), output_dir))
    mt.saveObject(object=StudySets, dest_path=output_dir + '/StudySets.json')

    print('Saving {} sorting results to {}/SortingResults.json'.format(
        len(SortingResults), output_dir))
    mt.saveObject(object=SortingResults,
                  dest_path=output_dir + '/SortingResults.json')

    print('Saving {} sorters to {}/Sorters.json'.format(
        len(Sorters), output_dir))
    mt.saveObject(object=Sorters, dest_path=output_dir + '/Sorters.json')

    print('Saving {} algorithms to {}/Algorithms.json'.format(
        len(Algorithms), output_dir))
    mt.saveObject(object=Algorithms, dest_path=output_dir + '/Algorithms.json')

    print('Saving {} study analysis results to {}/StudyAnalysisResults.json'.
          format(len(StudyAnalysisResults), output_dir))
    mt.saveObject(object=StudyAnalysisResults,
                  dest_path=output_dir + '/StudyAnalysisResults.json')

    print('Saving general info to {}/General.json'.format(output_dir))
    mt.saveObject(object=General, dest_path=output_dir + '/General.json')
Example #8
def main():
    parser = argparse.ArgumentParser(description='Show the publicly downloadable studies of SpikeForest')
    parser.add_argument('--group_names', help='Comma-separated list of recording group names.', required=False, default=None)

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.public'])

    if args.group_names is not None:
        group_names = args.group_names.split(',')
    else:
        group_names = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using group names: ', group_names)

    studies = []
    study_sets = []
    recordings = []
    for group_name in group_names:
        print('RECORDING GROUP: {}'.format(group_name))
        output_path = ('key://pairio/spikeforest/spikeforest_recording_group.{}.json').format(group_name)
        obj = mt.loadObject(path=output_path)
        if obj:
            studies = studies + obj['studies']
            study_sets = study_sets + obj.get('study_sets', [])
            recordings = recordings + obj['recordings']
            study_sets_by_study = dict()
            for study in obj['studies']:
                study_sets_by_study[study['name']] = study['study_set']
            for rec in obj['recordings']:
                if rec.get('public', False):
                    study_set = study_sets_by_study.get(rec.get('study', ''), '')
                    print('{}/{}/{}: {}'.format(study_set, rec.get('study', ''), rec.get('name', ''), rec.get('directory', '')))
        else:
            print('WARNING: unable to load object: ' + output_path)

    print('')
    print('ALL GROUPS')
    study_sets_by_study = dict()
    for study in studies:
        study_sets_by_study[study['name']] = study['study_set']
    for rec in recordings:
        if rec.get('public', False):
            study_set = study_sets_by_study.get(rec.get('study', ''), '')
            print('- {}/{}/{}: `{}`'.format(study_set, rec.get('study', ''), rec.get('name', ''), rec.get('directory', '')))
Example #9
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))

        max_samples = state.get('max_samples')
        max_dt_msec = state.get('max_dt_msec')
        bin_size_msec = state.get('bin_size_msec')
        if not max_dt_msec:
            return

        firings_path = state.get('firingsPath', None)
        if not firings_path:
            self.set_python_state(dict(
                status='error',
                status_message='No firingsPath provided'
            ))
            return

        samplerate = state.get('samplerate', None)
        if not samplerate:
            self.set_python_state(dict(
                status='error',
                status_message='No samplerate provided'
            ))
            return

        self.set_python_state(dict(status_message='Realizing file: {}'.format(firings_path)))
        firings_path2 = mt.realizeFile(firings_path)
        if not firings_path2:
            self.set_python_state(dict(
                status='error',
                status_message='Unable to realize file: {}'.format(firings_path)
            ))
            return

        result = ComputeAutocorrelograms.execute(
            firings_path=firings_path2,
            samplerate=samplerate,
            max_samples=max_samples,
            bin_size_msec=bin_size_msec,
            max_dt_msec=max_dt_msec,
            json_out=dict(ext='.json')
        )
        if result.retcode != 0:
            self.set_python_state(dict(
                status='error',
                status_message='Error computing autocorrelogram.'
            ))
            return

        output = mt.loadObject(path=result.outputs['json_out'])
        self.set_python_state(dict(
            status='finished',
            output=output
        ))
Example #10
def main():
    path = mt.createSnapshot(path='recordings_out')
    mt.configDownloadFrom('spikeforest.public')
    X = mt.readDir(path)
    for study_set_name, d in X['dirs'].items():
        study_sets = []
        studies = []
        recordings = []
        study_sets.append(dict(
            name=study_set_name + '_b',
            info=dict(),
            description=''
        ))
        for study_name, d2 in d['dirs'].items():
            study_dir = path + '/' + study_set_name + '/' + study_name
            study0 = dict(
                name=study_name,
                study_set=study_set_name + '_b',
                directory=study_dir,
                description=''
            )
            studies.append(study0)
            index_within_study = 0
            for recording_name, d3 in d2['dirs'].items():
                recdir = study_dir + '/' + recording_name
                recordings.append(dict(
                    name=recording_name,
                    study=study_name,
                    directory=recdir,
                    firings_true=recdir + '/firings_true.mda',
                    index_within_study=index_within_study,
                    description='One of the recordings in the {} study'.format(study_name)
                ))
                index_within_study = index_within_study + 1

        print('Saving object...')
        group_name = study_set_name + '_b'
        address = mt.saveObject(
            object=dict(
                studies=studies,
                recordings=recordings,
                study_sets=study_sets
            ),
            key=dict(name='spikeforest_recording_group', group_name=group_name),
            upload_to='spikeforest.public'
        )
        if not address:
            raise Exception('Problem uploading object to kachery')

        output_fname = 'key://pairio/spikeforest/spikeforest_recording_group.{}.json'.format(group_name)
        print('Saving output to {}'.format(output_fname))
        mt.createSnapshot(path=address, dest_path=output_fname)
Example #11
def mem_profile_test(sorting_processor, params, recording_dir, container='default', _keep_temp_files=True):
    mt.configDownloadFrom('spikeforest.public')
    params['fMemProfile'] = True
    recdir = recording_dir
    mt.createSnapshot(path=recdir, download_recursive=True)
    sorting = sorting_processor.execute(
        recording_dir=recdir,
        firings_out={'ext': '.mda'},
        **params,
        _container=container,
        _force_run=True,
        _keep_temp_files=_keep_temp_files
    )
Example #12
def main():
    path = 'sha1dir://8516cc54587e0c5ddd0709154e7f609b9b7884b4'
    mt.configDownloadFrom('spikeforest.public')
    X = mt.readDir(path)
    for study_set_name, d in X['dirs'].items():
        for study_name, d2 in d['dirs'].items():
            for recording_name, d3 in d2['dirs'].items():
                x = mt.loadObject(path=path + '/' + study_set_name + '/' +
                                  study_name + '/' + recording_name +
                                  '.runtime_info.json')
                print('{}/{}/{}\n{} sec\n'.format(study_set_name, study_name,
                                                  recording_name,
                                                  x['elapsed_sec']))
Example #13
    def javascript_state_changed(self, prev_state, state):
        self.set_state(dict(status='running', status_message='Running'))

        vtp_path = state.get('vtp_path', None)
        download_from = state.get('download_from', None)
        scalar_info = state.get('scalar_info', None)
        vector_field_info = state.get('vector_field_info', None)
        arrow_subsample_factor = state.get('arrow_subsample_factor', None)

        if not vtp_path:
            self.set_state(dict(status='error', status_message='No vtp_path'))
            return

        if download_from:
            mt.configDownloadFrom(download_from)
        fname = mt.realizeFile(path=vtp_path)

        reader = vtkXMLPolyDataReader()
        reader.SetFileName(fname)
        reader.Update()
        X = reader.GetOutput()
        vertices = vtk_to_numpy(X.GetPoints().GetData()).T
        faces = vtk_to_numpy(X.GetPolys().GetData())

        if scalar_info:
            scalars = vtk_to_numpy(X.GetPointData().GetArray(
                scalar_info['name']))
            scalars = scalars[:, scalar_info['component']]
        else:
            scalars = None

        if vector_field_info:
            vector_field = vtk_to_numpy(X.GetPointData().GetArray(
                vector_field_info['name']))
            vector_field = vector_field[:, vector_field_info['components']]
            arrows = [
                dict(start=vertices[:, j] - vector_field[j, :].T / 2,
                     end=vertices[:, j] + vector_field[j, :].T / 2) for j in
                range(0, vector_field.shape[0], arrow_subsample_factor)
            ]
        else:
            arrows = None

        self.set_state(
            dict(status='finished',
                 vertices=vertices,
                 faces=faces,
                 scalars=scalars,
                 arrows=arrows))
Example #14
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))

        # get javascript state
        download_from = state.get('download_from', [])
        path = state.get('path', None)
        name = state.get('name', None)

        if path and name:
            mt.configDownloadFrom(download_from)
            if path.endswith('.nwb'):
                self.set_python_state(
                    dict(status_message='Realizing object from nwb file: {}'.
                         format(path)))
                obj = nwb_to_dict(path, use_cache=True)
            else:
                self.set_python_state(
                    dict(status_message='Realizing object: {}'.format(path)))
                obj = mt.loadObject(path=path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             path)))
                return
            datasets = obj['general']['subject']['cortical_surfaces'][name][
                '_datasets']
            faces0 = np.load(mt.realizeFile(datasets['faces']['_data']))
            vertices = np.load(mt.realizeFile(datasets['vertices']['_data'])).T

            # there's a better way to do the following
            # (need to get it into a single vector format)
            faces = []
            for j in range(faces0.shape[0]):
                # 3 = #vertices in polygon (assume a triangulation)
                faces.extend([3, faces0[j, 0], faces0[j, 1], faces0[j, 2]])
            faces = np.array(faces)
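            # A vectorized alternative to the loop above (an untested sketch):
            # faces = np.hstack([np.full((faces0.shape[0], 1), 3, dtype=faces0.dtype), faces0]).ravel()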

            # return this python state to the javascript
            self.set_python_state(
                dict(faces=faces,
                     vertices=vertices,
                     status='finished',
                     status_message='Done.'))
        else:
            self.set_python_state(
                dict(status='error',
                     status_message='Missing path and/or name'))
Example #15
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))
        path = state.get('path', None)
        if path:
            self.set_python_state(
                dict(status_message='Realizing object: {}'.format(path)))
            obj = mt.loadObject(path=path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             path)))
                return
            state['object'] = obj
            state['status'] = 'finished'
            self.set_python_state(state)
Example #16
    def javascript_state_changed(self, prev_state, state):
        self.set_python_state(dict(status='running', status_message='Running'))
        mt.configDownloadFrom(state.get('download_from', []))
        nwb_path = state.get('nwb_path', None)
        downsample_factor = state.get('downsample_factor', 1)
        if nwb_path:
            if nwb_path.endswith('.nwb'):
                self.set_python_state(
                    dict(status_message='Realizing object from nwb file: {}'.
                         format(nwb_path)))
                obj = h5_to_dict(nwb_path, use_cache=True)
            else:
                self.set_python_state(
                    dict(status_message='Realizing object: {}'.format(
                        nwb_path)))
                obj = mt.loadObject(path=nwb_path)
            if not obj:
                self.set_python_state(
                    dict(status='error',
                         status_message='Unable to realize object: {}'.format(
                             nwb_path)))
                return
            try:
                positions_path = obj['processing']['Behavior']['Position'][
                    'Position']['_datasets']['data']['_data']
            except:
                self.set_python_state(
                    dict(status='error',
                         status_message=
                         'Problem extracting behavior positions in file: {}'.
                         format(nwb_path)))
                return
            positions = np.load(mt.realizeFile(path=positions_path))
            positions = positions[::downsample_factor, :]

            self.set_python_state(
                dict(status_message='Finished loading positions'))
            state['positions'] = positions
            state['status'] = 'finished'
            self.set_python_state(state)
        else:
            self.set_python_state(
                dict(status='error',
                     status_message='Missing in state: nwb_path'))
Example #17
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running surface3D')
        mt.configDownloadFrom(state.get('download_from', []))

        python_state = dict()

        path0 = state.get('faces_path', None)
        if path0:
            x = mt.realizeFile(path0)
            if not x:
                self._set_error('Unable to load file: {}'.format(path0))
                return
            faces0 = np.load(x)
            # there's a better way to do the following
            # (need to get it into a single vector format)
            faces = []
            for j in range(faces0.shape[0]):
                # 3 = #vertices in polygon (assume a triangulation)
                faces.extend([3, faces0[j, 0], faces0[j, 1], faces0[j, 2]])
            faces = np.array(faces)
            python_state['faces'] = faces

        path0 = state.get('vertices_path', None)
        if path0:
            x = mt.realizeFile(path0)
            if not x:
                self._set_error('Unable to load file: {}'.format(path0))
                return
            vertices0 = np.load(x)
            python_state['vertices'] = vertices0.T

        path0 = state.get('scalars_path', None)
        if path0:
            x = mt.realizeFile(path0)
            if not x:
                self._set_error('Unable to load file: {}'.format(path0))
                return
            x = np.load(x)
            python_state['scalars'] = x

        python_state['status'] = 'finished'
        python_state['status_message'] = 'finished'
        self.set_python_state(python_state)
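For context, a hedged sketch of preparing inputs for this surface3D widget (the file names and the mt.saveFile call are assumptions; the widget above expects .npy arrays with one row per face or vertex, reachable through mountaintools paths):

import numpy as np
from mountaintools import client as mt

faces0 = np.array([[0, 1, 2], [0, 2, 3]])  # triangle vertex indices, shape (num_faces, 3)
vertices0 = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])  # shape (num_vertices, 3)
np.save('faces.npy', faces0)
np.save('vertices.npy', vertices0)
faces_path = mt.saveFile(path='faces.npy')  # assumption: saveFile returns a sha1:// path usable as faces_path
vertices_path = mt.saveFile(path='vertices.npy')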
Example #18
def do_sorting_test(
        sorting_processor,
        params,
        recording_dir,
        container='default',
        force_run=True,
        _keep_temp_files=False
    ):
    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])

    recdir = recording_dir
    mt.createSnapshot(path=recdir, download_recursive=True)
    timer = time.time()
    sorting = sorting_processor.execute(
        recording_dir=recdir,
        firings_out={'ext': '.mda'},
        **params,
        _container=container,
        _force_run=force_run,
        _keep_temp_files=_keep_temp_files
    )
    elapsed = time.time() - timer
    print('################ ELAPSED for sorting (sec): {}'.format(elapsed))

    timer = time.time()
    comparison = sa.GenSortingComparisonTable.execute(
        firings=sorting.outputs['firings_out'],
        firings_true=recdir + '/firings_true.mda',
        units_true=[],
        json_out={'ext': '.json'},
        html_out={'ext': '.html'},
        _container='default',
        _force_run=True
    )
    elapsed = time.time() - timer
    print('################ ELAPSED for comparison (sec): {}'.format(elapsed))

    X = mt.loadObject(path=comparison.outputs['json_out'])
    accuracies = [float(a['accuracy']) for a in X.values()]
    avg_accuracy = np.mean(accuracies)

    print('Average accuracy: {}'.format(avg_accuracy))
Example #19
def main():
    parser = argparse.ArgumentParser(
        description='Upload public files, e.g., console outputs')
    parser.add_argument(
        '--output_ids',
        help='Comma-separated list of IDs of the analysis outputs to include.',
        required=False,
        default=None)

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c', 'paired_crcns', 'paired_mea64c',
            'paired_kampff', 'paired_monotrode', 'synth_monotrode',
            'synth_bionet', 'synth_magland', 'manual_franklab',
            'synth_mearec_neuronexus', 'synth_mearec_tetrode', 'synth_visapy',
            'hybrid_janelia'
        ]

    print('Using output ids: ', output_ids)

    for output_id in output_ids:
        print('Loading output object: {}'.format(output_id))
        output_path = (
            'key://pairio/spikeforest/spikeforest_analysis_results.{}.json'
        ).format(output_id)
        obj = mt.loadObject(path=output_path)
        paths = [
            sr['console_out'] for sr in obj['sorting_results']
            if 'console_out' in sr
        ]
        print('{}: {} sorting results - {} with console_out'.format(
            output_id, len(obj['sorting_results']), len(paths)))
        mt.createSnapshots(paths=paths, upload_to='spikeforest.public')
Example #20
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running')

        path = state.get('path', None)
        download_from = state.get('download_from', [])
        mt.configDownloadFrom(download_from)

        if not path:
            self._set_status('finished')
            return

        path = mt.realizeFile(path)
        if not path:
            self._set_error('Unable to realize file.')
            return

        with open(path, 'rb') as f:
            video_data = f.read()
        video_data_b64 = base64.b64encode(video_data).decode()
        self.set_python_state(
            dict(video_data_b64=video_data_b64,
                 status='finished',
                 status_message=''))
Example #21
    def javascript_state_changed(self, prev_state, state):
        path = state.get('path', None)
        download_from = state.get('download_from', [])
        mt.configDownloadFrom(download_from)

        if not path:
            self.set_python_state(
                dict(status='error', status_message='No path provided.'))
            return

        self.set_python_state(
            dict(status='running', status_message='Loading: {}'.format(path)))

        obj = mt.loadObject(path=path)
        if not obj:
            self.set_python_state(
                dict(status='error',
                     status_message='Unable to realize object: {}'.format(
                         path)))
            return

        obj['StudyAnalysisResults'] = None

        self.set_python_state(dict(status='finished', object=obj))
Example #22
def main():
    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument(
        '--output_ids',
        help=
        'Comma-separated list of IDs of the analysis outputs to include in the website.',
        required=False,
        default=None)

    args = parser.parse_args()

    use_slurm = True

    mt.configDownloadFrom(['spikeforest.kbucket'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'paired_monotrode',
            'synth_monotrode',
            # 'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using output ids: ', output_ids)

    print(
        '******************************** LOADING ANALYSIS OUTPUT OBJECTS...')
    for output_id in output_ids:
        slurm_working_dir = 'tmp_slurm_job_handler_' + _random_string(5)
        job_handler = mlpr.SlurmJobHandler(working_dir=slurm_working_dir)
        if use_slurm:
            job_handler.addBatchType(name='default',
                                     num_workers_per_batch=20,
                                     num_cores_per_job=1,
                                     time_limit_per_batch=1800,
                                     use_slurm=True,
                                     additional_srun_opts=['-p ccm'])
        else:
            job_handler.addBatchType(
                name='default',
                num_workers_per_batch=multiprocessing.cpu_count(),
                num_cores_per_job=1,
                use_slurm=False)
        with mlpr.JobQueue(job_handler=job_handler) as JQ:
            print('=============================================', output_id)
            print('Loading output object: {}'.format(output_id))
            output_path = (
                'key://pairio/spikeforest/spikeforest_analysis_results.{}.json'
            ).format(output_id)
            obj = mt.loadObject(path=output_path)
            # studies = obj['studies']
            # study_sets = obj.get('study_sets', [])
            # recordings = obj['recordings']
            sorting_results = obj['sorting_results']

            print(
                'Determining sorting results to process ({} total)...'.format(
                    len(sorting_results)))
            sorting_results_to_process = []
            for sr in sorting_results:
                key = dict(name='unit-details-v0.1.0',
                           recording_directory=sr['recording']['directory'],
                           firings_true=sr['firings_true'],
                           firings=sr['firings'])
                val = mt.getValue(key=key, collection='spikeforest')
                if not val:
                    sr['key'] = key
                    sorting_results_to_process.append(sr)

            print('Need to process {} of {} sorting results'.format(
                len(sorting_results_to_process), len(sorting_results)))

            recording_directories_to_process = sorted(
                list(
                    set([
                        sr['recording']['directory']
                        for sr in sorting_results_to_process
                    ])))
            print('{} recording directories to process'.format(
                len(recording_directories_to_process)))

            print('Filtering recordings...')
            filter_jobs = FilterTimeseries.createJobs([
                dict(recording_directory=recdir,
                     timeseries_out={'ext': '.mda'},
                     _timeout=600)
                for recdir in recording_directories_to_process
            ])
            filter_results = [job.execute() for job in filter_jobs]

            JQ.wait()

            filtered_timeseries_by_recdir = dict()
            for i, recdir in enumerate(recording_directories_to_process):
                result0 = filter_results[i]
                if result0.retcode != 0:
                    raise Exception(
                        'Problem computing filtered timeseries for recording: {}'
                        .format(recdir))
                filtered_timeseries_by_recdir[recdir] = result0.outputs[
                    'timeseries_out']

            print('Creating spike sprays...')
            for sr in sorting_results_to_process:
                rec = sr['recording']
                study_name = rec['study']
                rec_name = rec['name']
                sorter_name = sr['sorter']['name']

                print('====== COMPUTING {}/{}/{}'.format(
                    study_name, rec_name, sorter_name))

                if sr.get('comparison_with_truth', None) is not None:
                    cwt = mt.loadObject(
                        path=sr['comparison_with_truth']['json'])

                    filtered_timeseries = filtered_timeseries_by_recdir[
                        rec['directory']]

                    spike_spray_job_objects = []
                    list0 = list(cwt.values())
                    for _, unit in enumerate(list0):
                        # print('')
                        # print('=========================== {}/{}/{} unit {} of {}'.format(study_name, rec_name, sorter_name, ii + 1, len(list0)))
                        # ssobj = create_spikesprays(rx=rx, sx_true=sx_true, sx_sorted=sx, neighborhood_size=neighborhood_size, num_spikes=num_spikes, unit_id_true=unit['unit_id'], unit_id_sorted=unit['best_unit'])

                        spike_spray_job_objects.append(
                            dict(args=dict(
                                recording_directory=rec['directory'],
                                filtered_timeseries=filtered_timeseries,
                                firings_true=os.path.join(
                                    rec['directory'], 'firings_true.mda'),
                                firings_sorted=sr['firings'],
                                unit_id_true=unit['unit_id'],
                                unit_id_sorted=unit['best_unit'],
                                json_out={'ext': '.json'},
                                _container='default',
                                _timeout=180),
                                 study_name=study_name,
                                 rec_name=rec_name,
                                 sorter_name=sorter_name,
                                 unit=unit))
                    spike_spray_jobs = CreateSpikeSprays.createJobs(
                        [obj['args'] for obj in spike_spray_job_objects])
                    spike_spray_results = [
                        job.execute() for job in spike_spray_jobs
                    ]

                    sr['spike_spray_job_objects'] = spike_spray_job_objects
                    sr['spike_spray_results'] = spike_spray_results

            JQ.wait()

        for sr in sorting_results_to_process:
            rec = sr['recording']
            study_name = rec['study']
            rec_name = rec['name']
            sorter_name = sr['sorter']['name']

            print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                                  sorter_name))

            if sr.get('comparison_with_truth', None) is not None:
                spike_spray_job_objects = sr['spike_spray_job_objects']
                spike_spray_results = sr['spike_spray_results']

                unit_details = []
                ok = True
                for i, result in enumerate(spike_spray_results):
                    obj0 = spike_spray_job_objects[i]
                    if result.retcode != 0:
                        print('WARNING: Error creating spike sprays for job:')
                        print(spike_spray_job_objects[i])
                        ok = False
                        break
                    ssobj = mt.loadObject(path=result.outputs['json_out'])
                    if ssobj is None:
                        raise Exception(
                            'Problem loading spikespray object output.')
                    address = mt.saveObject(object=ssobj,
                                            upload_to='spikeforest.kbucket')
                    unit = obj0['unit']
                    unit_details.append(
                        dict(studyName=obj0['study_name'],
                             recordingName=obj0['rec_name'],
                             sorterName=obj0['sorter_name'],
                             trueUnitId=unit['unit_id'],
                             sortedUnitId=unit['best_unit'],
                             spikeSprayUrl=mt.findFile(
                                 path=address,
                                 remote_only=True,
                                 download_from='spikeforest.kbucket'),
                             _container='default'))
                if ok:
                    mt.saveObject(collection='spikeforest',
                                  key=sr['key'],
                                  object=unit_details,
                                  upload_to='spikeforest.public')
Example #23
def main():
    mt.configDownloadFrom('spikeforest.public')
    templates_path = 'sha1dir://95dba567b5168bacb480411ca334ffceb96b8c19.2019-06-11.templates'
    recordings_path = 'recordings_out'

    tempgen_tetrode = templates_path + '/templates_tetrode.h5'
    tempgen_neuronexus = templates_path + '/templates_neuronexus.h5'
    tempgen_neuropixels = templates_path + '/templates_neuropixels.h5'
    tempgen_neuronexus_drift = templates_path + '/templates_neuronexus_drift.h5'

    noise_level = [10, 20]
    duration = 600
    bursting = [False, True]
    nrec = 2  # change this to 10
    ei_ratio = 0.8
    rec_dict = {
        'tetrode': {
            'ncells': [10, 20],
            'tempgen': tempgen_tetrode,
            'drifting': False
        },
        'neuronexus': {
            'ncells': [10, 20, 40],
            'tempgen': tempgen_neuronexus,
            'drifting': False
        },
        'neuropixels': {
            'ncells': [20, 40, 60],
            'tempgen': tempgen_neuropixels,
            'drifting': False
        },
        'neuronexus_drift': {
            'ncells': [10, 20, 40],
            'tempgen': tempgen_neuronexus_drift,
            'drifting': True
        }
    }

    # optional: if drifting change drift velocity
    # recording_params['recordings']['drift_velocity'] = ...

    # Generate and save recordings
    if os.path.exists(recordings_path):
        shutil.rmtree(recordings_path)
    os.mkdir(recordings_path)

    # Set up slurm configuration
    slurm_working_dir = 'tmp_slurm_job_handler_' + _random_string(5)
    job_handler = mlpr.SlurmJobHandler(working_dir=slurm_working_dir)
    use_slurm = True
    job_timeout = 3600 * 4
    if use_slurm:
        job_handler.addBatchType(name='default',
                                 num_workers_per_batch=4,
                                 num_cores_per_job=6,
                                 time_limit_per_batch=job_timeout * 3,
                                 use_slurm=True,
                                 max_simultaneous_batches=20,
                                 additional_srun_opts=['-p ccm'])
    else:
        job_handler.addBatchType(
            name='default',
            num_workers_per_batch=multiprocessing.cpu_count(),
            num_cores_per_job=2,
            max_simultaneous_batches=1,
            use_slurm=False)
    with mlpr.JobQueue(job_handler=job_handler) as JQ:
        results_to_write = []
        for rec_type in rec_dict.keys():
            study_set_name = 'SYNTH_MEAREC_{}'.format(rec_type.upper())
            os.mkdir(recordings_path + '/' + study_set_name)
            params = dict()
            params['duration'] = duration
            params['drifting'] = rec_dict[rec_type]['drifting']
            # reduce minimum distance for dense recordings
            params['min_dist'] = 15
            for ncells in rec_dict[rec_type]['ncells']:
                # changing number of cells
                n_exc = int(ei_ratio *
                            10)  # intentionally replaced nrec by 10 here
                params['n_exc'] = n_exc
                params['n_inh'] = ncells - n_exc
                for n in noise_level:
                    # changing noise level
                    params['noise_level'] = n
                    for b in bursting:
                        bursting_str = ''
                        if b:
                            bursting_str = '_bursting'
                        study_name = 'synth_mearec_{}_noise{}_K{}{}'.format(
                            rec_type, n, ncells, bursting_str)
                        os.mkdir(recordings_path + '/' + study_set_name + '/' +
                                 study_name)
                        for i in range(nrec):
                            # set random seeds
                            params[
                                'seed'] = i  # intentionally doing it this way

                            # changing bursting and shape modulation
                            print('Generating', rec_type, 'recording with',
                                  ncells, 'noise level', n, 'bursting', b)
                            params['bursting'] = b
                            params['shape_mod'] = b
                            templates0 = mt.realizeFile(
                                path=rec_dict[rec_type]['tempgen'])
                            result0 = GenerateMearecRecording.execute(
                                **params,
                                templates_in=templates0,
                                recording_out=dict(ext='.h5'))
                            mda_output_folder = recordings_path + '/' + study_set_name + '/' + study_name + '/' + '{}'.format(
                                i)
                            results_to_write.append(
                                dict(result=result0,
                                     mda_output_folder=mda_output_folder))
        JQ.wait()

        for x in results_to_write:
            result0: mlpr.MountainJobResult = x['result']
            mda_output_folder = x['mda_output_folder']
            path = mt.realizeFile(path=result0.outputs['recording_out'])
            recording = se.MEArecRecordingExtractor(recording_path=path)
            sorting_true = se.MEArecSortingExtractor(recording_path=path)
            se.MdaRecordingExtractor.write_recording(
                recording=recording, save_path=mda_output_folder)
            se.MdaSortingExtractor.write_sorting(sorting=sorting_true,
                                                 save_path=mda_output_folder +
                                                 '/firings_true.mda')
            if result0.console_out:
                mt.realizeFile(path=result0.console_out,
                               dest_path=mda_output_folder +
                               '.console_out.txt')
            if result0.runtime_info:
                mt.saveObject(object=result0.runtime_info,
                              dest_path=mda_output_folder +
                              '.runtime_info.json')

    print('Creating and uploading snapshot...')
    sha1dir_path = mt.createSnapshot(path=recordings_path,
                                     upload_to='spikeforest.public',
                                     upload_recursive=False)
    # sha1dir_path = mt.createSnapshot(path=recordings_path, upload_to='spikeforest.kbucket', upload_recursive=True)
    print(sha1dir_path)
Example #24
#!/usr/bin/env python

from mountaintools import client as mt

mt.configDownloadFrom('spikeforest.kbucket')


def do_prepare(recording_group, study_name):
    print(recording_group, study_name)
    X = mt.loadObject(
        path="key://pairio/spikeforest/spikeforest_recording_group.{}.json".
        format(recording_group))
    studies = [y for y in X['studies'] if (y['name'] == study_name)]
    recordings = [y for y in X['recordings'] if y['study'] == study_name]
    recordings = recordings[0:1]
    study_sets = X['study_sets']

    Y = dict(studies=studies, recordings=recordings, study_sets=study_sets)
    address = mt.saveObject(object=Y)
    assert address is not None
    dest_path = 'key://pairio/spikeforest/spikeforest_recording_group.test_{}.json'.format(
        recording_group)
    print(dest_path)
    mt.createSnapshot(path=address,
                      upload_to='spikeforest.kbucket',
                      dest_path=dest_path)


do_prepare(recording_group='synth_magland',
           study_name='synth_magland_noise10_K10_C4')
do_prepare(recording_group='paired_mea64c', study_name='paired_mea64c')
Example #25
#!/usr/bin/env python

from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth
# You can also substitute any of the other available recordings
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')

# import a spike sorter from the spikeforestsorters module
from spikeforestsorters import MountainSort4
import os
import shutil

# In place of MountainSort4 you could use any of the following:
#
# MountainSort4, SpykingCircus, KiloSort, KiloSort2, YASS
# IronClust, HerdingSpikes2, JRClust, Tridesclous, Klusta
# although the Matlab sorters require further setup.

# clear and create an empty output directory (keep things tidy)
if os.path.exists('test_outputs'):
    shutil.rmtree('test_outputs')
os.makedirs('test_outputs', exist_ok=True)
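A hedged continuation of this example (a sketch, not the original script): run the sorter on the recording directory and load the resulting firings with the extractor imported above. The detect_sign and adjacency_radius parameters are assumptions for a tetrode recording.

sorting_output = MountainSort4.execute(
    recording_dir=recdir,
    firings_out={'ext': '.mda'},
    detect_sign=-1,
    adjacency_radius=50,
    _container='default')

sorting = SFMdaSortingExtractor(firings_file=sorting_output.outputs['firings_out'])
print('Found units:', sorting.get_unit_ids())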
Example #26
    def javascript_state_changed(self, prev_state, state):
        self.set_status('running')
        self.set_status_message('Running')

        mt.configDownloadFrom(state.get('download_from', []))
        nwb_query = state.get('nwb_query', None)
        downsample_factor = state.get('downsample_factor', 1)

        if nwb_query:
            self.set_status_message('Loading nwb object')
            obj = _load_nwb_object(nwb_query)
            if not obj:
                self.set_error('Unable to load nwb object')
                return

            self.set_status_message('Loading positions and timestamps')
            try:
                positions_path = obj['processing']['Behavior']['Position'][
                    'Position']['_datasets']['data']['_data']
                timestamps_path = obj['processing']['Behavior']['Position'][
                    'Position']['_datasets']['timestamps']['_data']
            except:
                self.set_error(
                    'Problem extracting behavior positions or timestamps')
                return
            positions = np.load(mt.realizeFile(path=positions_path))
            positions = positions[::downsample_factor, :]
            timestamps = np.load(mt.realizeFile(path=timestamps_path))
            timestamps = timestamps[::downsample_factor]

            self.set_status_message('Loading spike times')
            try:
                spike_times_path = obj['units']['_datasets']['spike_times'][
                    '_data']
                spike_times_index = obj['units']['_datasets'][
                    'spike_times_index']['_data']
                spike_times_index_id = obj['units']['_datasets']['id']['_data']
                if 'cluster_name' in obj['units']['_datasets']:
                    cluster_names = obj['units']['_datasets']['cluster_name'][
                        '_data']
                else:
                    cluster_names = []
            except:
                self.set_error('Problem extracting spike times')
                return
            spike_times = np.load(mt.realizeFile(path=spike_times_path))

            spike_time_indices = _find_closest(timestamps, spike_times)
            spike_labels = np.zeros(spike_time_indices.shape)
            aa = 0
            for i, val in enumerate(spike_times_index):
                spike_labels[aa:val] = spike_times_index_id[i]
                aa = val

            all_unit_ids = sorted(list(set(spike_labels)))

            state['positions'] = positions
            state['status'] = 'finished'
            state['spike_time_indices'] = spike_time_indices
            state['spike_labels'] = spike_labels
            state['all_unit_ids'] = all_unit_ids
            state['cluster_names'] = cluster_names
            self.set_python_state(state)
            self.set_status('finished')
        else:
            self.set_error('Missing in state: nwb_query')
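The helper _find_closest is not defined in this example; a minimal sketch of what it could do, assuming timestamps is sorted and each spike time should map to the index of the nearest timestamp:

import numpy as np

def _find_closest(timestamps, times):
    # For each value in times, return the index of the closest entry in timestamps.
    inds = np.searchsorted(timestamps, times)
    inds = np.clip(inds, 1, len(timestamps) - 1)
    left = timestamps[inds - 1]
    right = timestamps[inds]
    inds -= ((times - left) < (right - times)).astype(inds.dtype)
    return inds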
Example #27
def main():
    from mountaintools import client as mt

    parser = argparse.ArgumentParser(
        description=
        'Generate unit detail data (including spikesprays) for website')
    parser.add_argument('analysis_path',
                        help='assembled analysis file (output.json)')
    parser.add_argument(
        '--studysets',
        help='Comma-separated list of study set names to include',
        required=False,
        default=None)
    parser.add_argument('--force-run',
                        help='Force rerunning of processing',
                        action='store_true')
    parser.add_argument(
        '--force-run-all',
        help='Force rerunning of processing including filtering',
        action='store_true')
    parser.add_argument('--parallel',
                        help='Optional number of parallel jobs',
                        required=False,
                        default='0')
    parser.add_argument('--slurm',
                        help='Path to slurm config file',
                        required=False,
                        default=None)
    parser.add_argument('--cache',
                        help='The cache database to use',
                        required=False,
                        default=None)
    parser.add_argument('--job-timeout',
                        help='Timeout for processing jobs',
                        required=False,
                        default=600)
    parser.add_argument('--log-file',
                        help='Log file for analysis progress',
                        required=False,
                        default=None)
    parser.add_argument(
        '--force-regenerate',
        help=
        'Whether to force regenerating spike sprays (for when code has changed)',
        action='store_true')
    parser.add_argument('--test',
                        help='Whether to just test by running only one sorting result',
                        action='store_true')

    args = parser.parse_args()

    mt.configDownloadFrom(['spikeforest.kbucket'])

    with open(args.analysis_path, 'r') as f:
        analysis = json.load(f)

    if args.studysets is not None:
        studyset_names = args.studysets.split(',')
        print('Using study sets: ', studyset_names)
    else:
        studyset_names = None

    study_sets = analysis['StudySets']
    sorting_results = analysis['SortingResults']

    studies_to_include = []
    for ss in study_sets:
        if (studyset_names is None) or (ss['name'] in studyset_names):
            for study in ss['studies']:
                studies_to_include.append(study['name'])

    print('Including studies:', studies_to_include)

    print('Determining sorting results to process ({} total)...'.format(
        len(sorting_results)))
    sorting_results_to_process = []
    sorting_results_to_consider = []
    for sr in sorting_results:
        study_name = sr['studyName']
        if study_name in studies_to_include:
            if 'firings' in sr:
                if sr.get('comparisonWithTruth', None) is not None:
                    sorting_results_to_consider.append(sr)
                    key = dict(name='unit-details-v0.1.0',
                               recording_directory=sr['recordingDirectory'],
                               firings_true=sr['firingsTrue'],
                               firings=sr['firings'])
                    val = mt.getValue(key=key, collection='spikeforest')
                    if (not val) or (args.force_regenerate):
                        sr['key'] = key
                        sorting_results_to_process.append(sr)
    if args.test and len(sorting_results_to_process) > 0:
        sorting_results_to_process = [sorting_results_to_process[0]]

    print('Need to process {} of {} sorting results'.format(
        len(sorting_results_to_process), len(sorting_results_to_consider)))

    recording_directories_to_process = sorted(
        list(
            set([
                sr['recordingDirectory'] for sr in sorting_results_to_process
            ])))
    print('{} recording directories to process'.format(
        len(recording_directories_to_process)))

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                             **slurm_config['cpu'])
    else:
        job_handler = None

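    # Phase 1: filter each recording directory once (one filter_recording job per
    # directory), so the filtered timeseries can be reused by every sorting result
    # on that recording when the spike sprays are computed below.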
    print('Filtering recordings...')
    filter_results = []
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for recdir in recording_directories_to_process:
            result = filter_recording.run(recording_directory=recdir,
                                          timeseries_out=hither.File())
            filter_results.append(result)
    filtered_timeseries_by_recdir = dict()
    for i, recdir in enumerate(recording_directories_to_process):
        result0 = filter_results[i]
        if not result0.success:
            raise Exception(
                'Problem computing filtered timeseries for recording: {}'.format(recdir))
        filtered_timeseries_by_recdir[recdir] = result0.outputs.timeseries_out._path

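    # Phase 2: for each sorting result, compute one spike-spray object per
    # ground-truth unit, pairing the true unit id with the best-matching sorted
    # unit id from the comparison-with-truth output.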
    print('Creating spike sprays...')
    with hither.config(container='default',
                       cache=args.cache,
                       force_run=args.force_run or args.force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file,
                       exception_on_fail=True,
                       cache_failing=False,
                       rerun_failing=True,
                       job_timeout=args.job_timeout), hither.job_queue():
        for sr in sorting_results_to_process:
            recdir = sr['recordingDirectory']
            study_name = sr['studyName']
            rec_name = sr['recordingName']
            sorter_name = sr['sorterName']

            print('====== COMPUTING {}/{}/{}'.format(study_name, rec_name,
                                                     sorter_name))

            cwt = ka.load_object(path=sr['comparisonWithTruth']['json'])

            filtered_timeseries = filtered_timeseries_by_recdir[recdir]

            spike_spray_results = []
            for unit in cwt.values():
                result = create_spike_sprays.run(
                    recording_directory=recdir,
                    filtered_timeseries=filtered_timeseries,
                    firings_true=os.path.join(recdir, 'firings_true.mda'),
                    firings_sorted=sr['firings'],
                    unit_id_true=unit['unit_id'],
                    unit_id_sorted=unit['best_unit'],
                    json_out=hither.File())
                setattr(result, 'unit', unit)
                spike_spray_results.append(result)
            sr['spike_spray_results'] = spike_spray_results

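    # Phase 3: upload the spike-spray objects and store the assembled unit details
    # in the 'spikeforest' collection under the key computed earlier, skipping any
    # sorting result whose spike sprays failed.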
    for sr in sorting_results_to_process:
        recdir = sr['recordingDirectory']
        study_name = sr['studyName']
        rec_name = sr['recordingName']
        sorter_name = sr['sorterName']

        print('====== SAVING {}/{}/{}'.format(study_name, rec_name,
                                              sorter_name))
        spike_spray_results = sr['spike_spray_results']
        key = sr['key']

        unit_details = []
        ok = True
        for result in spike_spray_results:
            if not result.success:
                print(
                    'WARNING: Error creating spike sprays for {}/{}/{}'.format(
                        study_name, rec_name, sorter_name))
                ok = False
                break
            ssobj = ka.load_object(result.outputs.json_out._path)
            if ssobj is None:
                raise Exception('Problem loading spikespray object output.')
            address = mt.saveObject(object=ssobj,
                                    upload_to='spikeforest.kbucket')
            unit = getattr(result, 'unit')
            unit_details.append(
                dict(studyName=study_name,
                     recordingName=rec_name,
                     sorterName=sorter_name,
                     trueUnitId=unit['unit_id'],
                     sortedUnitId=unit['best_unit'],
                     spikeSprayUrl=mt.findFile(
                         path=address,
                         remote_only=True,
                         download_from='spikeforest.kbucket')))

        if ok:
            mt.saveObject(collection='spikeforest',
                          key=key,
                          object=unit_details,
                          upload_to='spikeforest.public')
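
The unit details saved above live under the same 'unit-details-v0.1.0' key that the script checks before deciding what needs processing, so a downstream consumer can look them up with the same mountaintools calls. A minimal retrieval sketch, assuming (as the check near the top of the script suggests) that the stored value resolves to an address loadable with mt.loadObject; the three path values are hypothetical placeholders:

# Hypothetical retrieval sketch -- the paths below are placeholders, and it is an
# assumption that mt.getValue returns an address that mt.loadObject can resolve.
key = dict(name='unit-details-v0.1.0',
           recording_directory='sha1dir://<recording-directory>',  # placeholder
           firings_true='sha1://<firings_true.mda>',  # placeholder
           firings='sha1://<firings.mda>')  # placeholder
val = mt.getValue(key=key, collection='spikeforest')
if val:
    unit_details = mt.loadObject(path=val)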
Ejemplo n.º 28
0
def main():
    parser = argparse.ArgumentParser(description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--output_ids', help='Comma-separated list of IDs of the analysis outputs to include in the website.', required=False, default=None)
    parser.add_argument('--upload_to', help='Optional comma-separated list of kachery shares to upload to', required=False, default=None)
    parser.add_argument('--dest_key_path', help='Optional destination key path', required=False, default=None)

    args = parser.parse_args()

    if args.upload_to:
        upload_to = args.upload_to.split(',')
    else:
        upload_to = None

    mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])

    if args.output_ids is not None:
        output_ids = args.output_ids.split(',')
    else:
        output_ids = [
            'paired_boyden32c',
            'paired_crcns',
            'paired_mea64c',
            'paired_kampff',
            'paired_monotrode',
            'synth_monotrode',
            'synth_bionet',
            'synth_magland',
            'manual_franklab',
            'synth_mearec_neuronexus',
            'synth_mearec_tetrode',
            # 'SYNTH_MEAREC_TETRODE_b',
            # 'SYNTH_MEAREC_NEURONEXUS_b',
            # 'SYNTH_MEAREC_NEUROPIXELS_b',
            'synth_visapy',
            'hybrid_janelia'
        ]
    print('Using output ids: ', output_ids)

    sorters_to_include = set([
        'HerdingSpikes2',
        'IronClust',
        'IronClust1',
        'IronClust2',
        'IronClust3',
        'IronClust4',
        'IronClust5',
        'IronClust6',
        'IronClust7',
        'IronClust8',
        'JRClust',
        'KiloSort',
        'KiloSort2',
        'Klusta',
        'MountainSort4',
        'SpykingCircus',
        'Tridesclous',
        'Waveclus',
        # 'Yass'
    ])

    print('******************************** LOADING ANALYSIS OUTPUT OBJECTS...')
    studies = []
    study_sets = []
    recordings = []
    sorting_results = []
    for output_id in output_ids:
        print('Loading output object: {}'.format(output_id))
        output_path = 'key://pairio/spikeforest/spikeforest_analysis_results.{}.json'.format(output_id)
        obj = mt.loadObject(path=output_path)
        if obj is not None:
            studies = studies + obj['studies']
            print(obj.keys())
            study_sets = study_sets + obj.get('study_sets', [])
            recordings = recordings + obj['recordings']
            sorting_results = sorting_results + obj['sorting_results']
        else:
            raise Exception('Unable to load: {}'.format(output_path))

    # ALGORITHMS
    print('******************************** ASSEMBLING ALGORITHMS...')
    algorithms_by_processor_name = dict()
    Algorithms = []
    basepath = '../../spikeforest/spikeforestsorters/descriptions'
    repo_base_url = 'https://github.com/flatironinstitute/spikeforest/blob/master'
    for item in os.listdir(basepath):
        if item.endswith('.md'):
            alg = frontmatter.load(basepath + '/' + item).to_dict()
            alg['markdown_link'] = repo_base_url + '/spikeforest/spikeforestsorters/descriptions/' + item
            alg['markdown'] = alg['content']
            del alg['content']
            if 'processor_name' in alg:
                algorithms_by_processor_name[alg['processor_name']] = alg
            Algorithms.append(alg)
    print([alg['label'] for alg in Algorithms])

    # # STUDIES
    # print('******************************** ASSEMBLING STUDIES...')
    # studies_by_name = dict()
    # for study in studies:
    #     studies_by_name[study['name']] = study
    #     study['recordings'] = []
    # for recording in recordings:
    #     studies_by_name[recording['studyName']]['recordings'].append(dict(
    #         name=recording['name'],
    #         studyName=recording['study_name'],
    #         directory=recording['directory'],
    #     ))

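    # NOTE: the Studies list assembled below is only used for the printout; the
    # per-study records that end up in the output object are built into StudySets
    # in the following section.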
    Studies = []
    for study in studies:
        Studies.append(dict(
            name=study['name'],
            studySet=study['study_set'],
            description=study['description'],
            recordings=[]
            # the following can be obtained from the other collections
            # numRecordings, sorters, etc...
        ))
    print([S['name'] for S in Studies])

    print('******************************** ASSEMBLING STUDY SETS...')
    study_sets_by_name = dict()
    for study_set in study_sets:
        study_sets_by_name[study_set['name']] = study_set
        study_set['studies'] = []
    studies_by_name = dict()
    for study in studies:
        study0 = dict(
            name=study['name'],
            studySetName=study['study_set'],
            recordings=[]
        )
        study_sets_by_name[study['study_set']]['studies'].append(study0)
        studies_by_name[study0['name']] = study0
    for recording in recordings:
        true_units_info = mt.loadObject(path=recording['summary']['true_units_info'])
        if not true_units_info:
            print(recording['summary']['true_units_info'])
            raise Exception('Unable to load true_units_info for recording {}'.format(recording['name']))
        recording0 = dict(
            name=recording['name'],
            studyName=recording['study'],
            studySetName=studies_by_name[recording['study']]['studySetName'],
            directory=recording['directory'],
            firingsTrue=recording['firings_true'],
            sampleRateHz=recording['summary']['computed_info']['samplerate'],
            numChannels=recording['summary']['computed_info']['num_channels'],
            durationSec=recording['summary']['computed_info']['duration_sec'],
            numTrueUnits=len(true_units_info),
            spikeSign=-1  # TODO: set this properly
        )
        studies_by_name[recording0['studyName']]['recordings'].append(recording0)
    StudySets = []
    for study_set in study_sets:
        StudySets.append(study_set)
    print(StudySets)

    # SORTING RESULTS
    print('******************************** SORTING RESULTS...')
    SortingResults = []
    for sr in sorting_results:
        if sr['sorter']['name'] in sorters_to_include:
            SR = dict(
                recordingName=sr['recording']['name'],
                studyName=sr['recording']['study'],
                sorterName=sr['sorter']['name'],
                recordingDirectory=sr['recording']['directory'],
                firingsTrue=sr['recording']['firings_true'],
                consoleOut=sr['console_out'],
                container=sr['container'],
                cpuTimeSec=sr['execution_stats'].get('elapsed_sec', None),
                returnCode=sr['execution_stats'].get('retcode', 0),  # TODO: in future, the default should not be 0 -- rather it should be a required field of execution_stats
                timedOut=sr['execution_stats'].get('timed_out', False),
                startTime=datetime.fromtimestamp(sr['execution_stats'].get('start_time')).isoformat(),
                endTime=datetime.fromtimestamp(sr['execution_stats'].get('end_time')).isoformat()
            )
            if sr.get('firings', None):
                SR['firings'] = sr['firings']
                if not sr.get('comparison_with_truth', None):
                    print('Warning: comparison with truth not found for sorting result: {} {}/{}'.format(sr['sorter']['name'], sr['recording']['study'], sr['recording']['name']))
                    print('Console output is here: ' + sr['console_out'])
            else:
                print('Warning: firings not found for sorting result: {} {}/{}'.format(sr['sorter']['name'], sr['recording']['study'], sr['recording']['name']))
                print('Console output is here: ' + sr['console_out'])
            SortingResults.append(SR)
    # print('Num unit results:', len(UnitResults))

    # SORTERS
    print('******************************** ASSEMBLING SORTERS...')
    sorters_by_name = dict()
    for sr in sorting_results:
        sorters_by_name[sr['sorter']['name']] = sr['sorter']
    Sorters = []
    sorter_names = sorted(list(sorters_by_name.keys()))
    sorter_names = [sorter_name for sorter_name in sorter_names if sorter_name in sorters_to_include]
    for sorter_name in sorter_names:
        sorter = sorters_by_name[sorter_name]
        alg = algorithms_by_processor_name.get(sorter['processor_name'], dict())
        alg_label = alg.get('label', sorter['processor_name'])
        if sorter['name'] in sorters_to_include:
            Sorters.append(dict(
                name=sorter['name'],
                algorithmName=alg_label,
                processorName=sorter['processor_name'],
                processorVersion='0',  # jfm to provide this
                sortingParameters=sorter['params']
            ))
    print([S['name'] + ':' + S['algorithmName'] for S in Sorters])

    # STUDY ANALYSIS RESULTS
    print('******************************** ASSEMBLING STUDY ANALYSIS RESULTS...')
    StudyAnalysisResults = [
        _assemble_study_analysis_result(
            study_name=study['name'],
            study_set_name=study['study_set'],
            recordings=recordings,
            sorting_results=sorting_results,
            sorter_names=sorter_names
        )
        for study in studies
    ]

    # GENERAL
    print('******************************** ASSEMBLING GENERAL INFO...')
    General = [dict(
        dateUpdated=datetime.now().isoformat(),
        packageVersions=dict(
            mountaintools=pkg_resources.get_distribution("mountaintools").version,
            spikeforest=pkg_resources.get_distribution("spikeforest").version
        )
    )]

    obj = dict(
        mode='spike-front',
        StudySets=StudySets,
        # TrueUnits=TrueUnits,
        # UnitResults=UnitResults,
        SortingResults=SortingResults,
        Sorters=Sorters,
        Algorithms=Algorithms,
        StudyAnalysisResults=StudyAnalysisResults,
        General=General
    )
    address = mt.saveObject(object=obj)
    mt.createSnapshot(path=address, upload_to=upload_to, dest_path=args.dest_key_path)
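
Once the snapshot has been created, downstream tools can read the assembled object back from the destination key path. A minimal read-back sketch, assuming a key:// destination was supplied via --dest_key_path (the path shown is an illustrative placeholder):

# Hypothetical read-back of the snapshot created above; the key:// path is a
# placeholder standing in for whatever was passed as --dest_key_path.
mt.configDownloadFrom(['spikeforest.kbucket', 'spikeforest.public'])
obj2 = mt.loadObject(path='key://pairio/<user>/<dest-key-path>.json')
assert obj2, 'Unable to load assembled analysis object'
print(obj2['General'][0]['dateUpdated'])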