Ejemplo n.º 1
0
def main(data_dir='/data'):
    """Ingest every ``*.nwb`` imaging session found directly in ``data_dir``.

    For each file the experimenter, subject, session, trials, trial events,
    scan, ROIs and go-cue-aligned trial traces are inserted into the ``lab``,
    ``experiment`` and ``imaging`` tables. A file is skipped when
    ``imaging.TrialTrace`` already holds rows matching its session key.

    Args:
        data_dir: Directory holding files named
            ``nwb_<subject>_<date>_<brain area>_<depth>um.nwb``.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
    """
    data_dir = pathlib.Path(data_dir)
    if not data_dir.exists():
        raise FileNotFoundError(f'Path not found!! {data_dir.as_posix()}')

    # ==================== DEFINE CONSTANTS =====================
    # NWB trial tag -> (outcome, trial_instruction, early_lick)
    trial_type_mapper = {'CorrR': ('hit', 'right', 'no early'),
                         'CorrL': ('hit', 'left', 'no early'),
                         'ErrR': ('miss', 'right', 'no early'),
                         'ErrL': ('miss', 'left', 'no early'),
                         'NoLickR': ('ignore', 'right', 'no early'),
                         'NoLickL': ('ignore', 'left', 'no early'),
                         'LickEarly': ('non-performing', 'non-performing', 'early')
                         }

    # NWB cell-type label -> label stored with each ROI
    cell_type_mapper = {'PT': 'PT', 'IT': 'IT', 'unidentified': 'N/A', 'in': 'interneuron', 'pn': 'Pyr'}

    task_protocol = {'task': 'audio delay', 'task_protocol': 1}

    kargs = dict(skip_duplicates=True,
                 ignore_extra_fields=True,
                 allow_direct_insert=True)

    # Raw string so that ``\d`` is a regex class rather than an (invalid)
    # string escape; the dot before the extension is matched literally.
    fname_pattern = re.compile(r'nwb_(an\d+)_(\d+)_(.*)_(\d+)um\.nwb')

    # ================== INGESTION OF DATA ==================
    for data_file in data_dir.glob('*.nwb'):
        print(f'-- Read {data_file} --')

        fname = data_file.name

        # Open the file that was actually globbed (previously the path was
        # rebuilt from a hard-coded '/data', ignoring ``data_dir``) and let the
        # context manager close the HDF5 handle on every exit path.
        with h5py.File(str(data_file), 'r') as nwb:
            experimenter = nwb['general/experimenter'][()]
            lab.Person.insert1((experimenter, experimenter), **kargs)

            # ----------- meta-data -----------
            search_result = fname_pattern.search(fname)
            subject_nickname = search_result.group(1)
            # Kept as the full value returned by parse_date (not ``.date()``),
            # which is what the session key effectively received before.
            session_date = parse_date(search_result.group(2))
            brain_location = 'left_' + search_result.group(3).lower()
            depth = int(search_result.group(4))

            subject = dict(subject_nickname=subject_nickname,
                           username=experimenter,
                           sex='M', species='Mus musculus',
                           animal_source='Jackson Labs')  # inferred from paper

            modified_gene = dict(
                gene_modification='Thy1-GCaMP6s',
                gene_modification_description='GP4.3')

            lab.ModifiedGene.insert1(modified_gene, **kargs)
            lab.Subject.insert1(subject, skip_duplicates=True)

            # Next session number for this subject: max existing + 1, else 1.
            # NOTE(review): the renamed attribute '******' looks like redacted
            # text in the source; left untouched - confirm the intended name.
            if len(experiment.Session & subject):
                session = (lab.Subject & subject).aggr(
                    experiment.Session.proj(session_user='******'),
                    session_max='max(session)').fetch1('session_max') + 1
            else:
                session = 1

            current_session = dict(
                **subject, session_date=session_date,
                brain_location_name=brain_location,
                session=session)

            if imaging.TrialTrace & current_session:
                print('Data ingested, skipping over...')
                continue

            experiment.Session.insert1(current_session, **kargs)
            experiment.Session.ImagingDepth.insert1(
                dict(**current_session, imaging_depth=depth), **kargs)

            # ---- trial data ----

            print('---- Ingesting trial data ----')
            session_trials, behavioral_trials, trial_events = [], [], []

            pole_in = nwb['stimulus/presentation/pole_in/timestamps'][()]
            pole_out = nwb['stimulus/presentation/pole_out/timestamps'][()]
            auditory_cue = nwb['stimulus/presentation/auditory_cue/timestamps'][()]

            for i_trial in range(len(nwb['epochs'])):
                trial_num = i_trial + 1  # epochs are named trial_001, trial_002, ...
                trial_path = 'epochs/trial_{0:03}/'.format(trial_num)
                tkey = dict(**current_session, trial=trial_num)

                trial_start = nwb[trial_path + 'start_time'][()]
                trial_stop = nwb[trial_path + 'stop_time'][()]
                # Event times are stored relative to the trial start.
                sample_start = pole_in[i_trial] - trial_start
                delay_start = pole_out[i_trial] - trial_start
                response_start = auditory_cue[i_trial] - trial_start

                session_trials.append(
                    dict(**tkey, start_time=Decimal(trial_start),
                         stop_time=Decimal(trial_stop)))

                outcome, trial_instruction, early = trial_type_mapper[
                    nwb[trial_path + 'tags'][()][0]]
                behavioral_trials.append(
                    dict(**tkey, **task_protocol,
                         trial_instruction=trial_instruction,
                         outcome=outcome,
                         early_lick=early))

                for etype, etime in zip(('sample', 'delay', 'go'),
                                        (sample_start, delay_start, response_start)):
                    if not np.isnan(etime):
                        # Event ids run continuously across the whole session.
                        trial_events.append(
                            dict(**tkey, trial_event_id=len(trial_events) + 1,
                                 trial_event_type=etype, trial_event_time=etime))

            # insert trial info
            experiment.SessionTrial.insert(session_trials, **kargs)
            experiment.BehaviorTrial.insert(behavioral_trials, **kargs)
            experiment.TrialEvent.insert(trial_events, **kargs)

            # ---- Scan info ----

            print('---- Ingesting imaging data ----')

            # The first ROI's timestamps serve as the frame times of the scan.
            frame_time = nwb['processing/ROIs/ROI_001/timestamps'][()]

            imaging.Scan.insert1(
                dict(**current_session,
                     image_gcamp=nwb['acquisition/images/green'][()],
                     frame_time=frame_time),
                **kargs)

            # trial -> (trial start time, go-cue time relative to trial start)
            tr_events = {tr: (float(stime), float(gotime)) for tr, stime, gotime in
                         zip(*(experiment.SessionTrial * experiment.TrialEvent
                               & current_session & 'trial_event_type = "go"').fetch(
                             'trial', 'start_time', 'trial_event_time'))}

            # The go-cue frame of a trial does not depend on the ROI, so it is
            # located once here instead of once per ROI.
            go_cue_frames = []
            for tr, (start_time, go_time) in tr_events.items():
                go_cue_time = start_time + go_time
                go_id = np.abs(frame_time - go_cue_time).argmin()
                go_cue_frames.append((tr, go_cue_time, go_id))

            rois, trial_traces = [], []
            for i_roi in range(len(nwb['processing/ROIs'])):
                roi_idx = i_roi + 1  # ROIs are named ROI_001, ROI_002, ...
                roi_path = 'processing/ROIs/ROI_{0:03}/'.format(roi_idx)
                roi_trace = nwb[roi_path + 'fmean'][()]
                neuropil_trace = nwb[roi_path + 'fmean_neuropil'][()]
                roi_trace_corrected = roi_trace - 0.7 * neuropil_trace

                rois.append(
                    dict(**current_session,
                         roi_idx=roi_idx,
                         roi_trace=roi_trace,
                         neuropil_trace=neuropil_trace,
                         roi_trace_corrected=roi_trace_corrected,
                         cell_type=cell_type_mapper[nwb[roi_path + 'cell_type'][()][0]],
                         ap_position=nwb[roi_path + 'AP_position_from_bregma_in_micron'][()][0],
                         ml_position=nwb[roi_path + 'ML_position_from_bregma_in_micron'][()][0],
                         roi_pixel_list=nwb[roi_path + 'pixel_list'][()],
                         inc=bool(np.mean(roi_trace) / np.mean(neuropil_trace) > 1.05)))

                for tr, go_cue_time, go_id in go_cue_frames:
                    # Window of 70 frames before to 45 frames after the go cue;
                    # the first 6 frames of the window are the dF/F baseline.
                    # NOTE(review): when go_id < 70 the negative start wraps
                    # around and the window is not what is intended - confirm
                    # whether such trials should be skipped.
                    idx = slice(go_id - 70, go_id + 45, 1)
                    baseline = np.mean(roi_trace_corrected[go_id - 70:go_id - 64])
                    window_time = frame_time[idx]
                    trial_traces.append(
                        dict(**current_session, roi_idx=roi_idx, trial=tr,
                             original_time=window_time,
                             aligned_time=window_time - go_cue_time,
                             aligned_trace=roi_trace[idx],
                             aligned_trace_corrected=roi_trace_corrected[idx],
                             dff=(roi_trace_corrected[idx] - baseline) / baseline))

            imaging.Scan.Roi.insert(rois, **kargs)
            imaging.TrialTrace.insert(trial_traces, **kargs)
Ejemplo n.º 2
0
def main(meta_data_dir='./data/meta_data', reingest=False):
    """Ingest session metadata from every ``*.mat`` file in ``meta_data_dir``.

    For each file this inserts the experimenter, gene modifications, strains,
    animal source and subject, then a new session with its probe insertion
    (location and electrode site positions) and, when the corresponding
    structs are present, virus injections and photostimulation settings.

    Args:
        meta_data_dir: Directory containing the ``*.mat`` metadata files, each
            holding a ``meta_data`` struct.
        reingest: When true, ``experiment.Session.delete()`` is called before
            any file is processed.

    Raises:
        FileNotFoundError: If ``meta_data_dir`` does not exist.
    """
    meta_data_dir = pathlib.Path(meta_data_dir)
    if not meta_data_dir.exists():
        raise FileNotFoundError(f'Path not found!! {meta_data_dir.as_posix()}')

    # ==================== DEFINE CONSTANTS =====================

    # ---- inferred from paper ----
    hemi = 'left'
    skull_reference = 'bregma'
    # photostim wavelength -> device name
    photostim_devices = {
        473: 'LaserGem473',
        594: 'LaserCoboltMambo100',
        596: 'LaserCoboltMambo100'
    }

    # ---- from lookup ----
    probe = 'A4x8-5mm-100-200-177'
    electrode_config_name = 'silicon32'

    # ================== INGESTION OF METADATA ==================

    # ---- delete all Sessions ----
    if reingest:
        experiment.Session.delete()

    # ---- insert metadata ----
    meta_data_files = meta_data_dir.glob('*.mat')
    for meta_data_file in tqdm(meta_data_files):
        print(f'-- Read {meta_data_file} --')
        meta_data = sio.loadmat(meta_data_file,
                                struct_as_record=False,
                                squeeze_me=True)['meta_data']

        # ==================== person ====================
        person_key = dict(username=meta_data.experimenters,
                          fullname=meta_data.experimenters)
        lab.Person.insert1(person_key, skip_duplicates=True)

        # ==================== subject gene modification ====================
        # squeeze_me=True collapses one-element arrays to scalars, so scalar
        # fields are wrapped back into a list before iterating.
        modified_genes = (meta_data.animalGeneModification if isinstance(
            meta_data.animalGeneModification,
            (np.ndarray, list)) else [meta_data.animalGeneModification])
        lab.ModifiedGene.insert(
            (dict(gene_modification=g, gene_modification_description=g)
             for g in modified_genes),
            skip_duplicates=True)

        # ==================== subject strain ====================
        animal_strains = (meta_data.animalStrain if isinstance(
            meta_data.animalStrain,
            (np.ndarray, list)) else [meta_data.animalStrain])
        lab.AnimalStrain.insert(zip(animal_strains), skip_duplicates=True)

        # ==================== subject ====================
        animal_id = (meta_data.animalID[0] if isinstance(
            meta_data.animalID, (np.ndarray, list)) else meta_data.animalID)
        animal_source = (meta_data.animalSource[0] if isinstance(
            meta_data.animalSource,
            (np.ndarray, list)) else meta_data.animalSource)
        subject_key = dict(
            subject_id=int(re.search(r'\d+', animal_id).group()),
            sex=meta_data.sex[0].upper() if len(meta_data.sex) != 0 else 'U',
            species=meta_data.species,
            animal_source=animal_source)

        # Date of birth is optional: a missing or unparsable value leaves the
        # field unset. Only the errors expected from that case are swallowed,
        # so unrelated failures are no longer hidden.
        # NOTE(review): assumes parse_date reports bad input with
        # ValueError/TypeError/OverflowError (as dateutil's parser does) -
        # confirm against the actual import.
        try:
            subject_key['date_of_birth'] = parse_date(meta_data.dateOfBirth)
        except (AttributeError, TypeError, ValueError, OverflowError):
            pass

        lab.AnimalSource.insert1((animal_source, ), skip_duplicates=True)

        # Subject and its part-table rows are inserted atomically, and only
        # when the subject is not already present.
        with lab.Subject.connection.transaction:
            if subject_key not in lab.Subject.proj():
                lab.Subject.insert1(subject_key)
                lab.Subject.GeneModification.insert(
                    (dict(subject_key, gene_modification=g)
                     for g in modified_genes),
                    ignore_extra_fields=True)
                lab.Subject.Strain.insert(
                    (dict(subject_key, animal_strain=strain)
                     for strain in animal_strains),
                    ignore_extra_fields=True)

        # ==================== session ====================
        # Session number is the count of this subject's existing sessions + 1.
        session_key = dict(
            subject_key,
            username=person_key['username'],
            session=len(experiment.Session & subject_key) + 1,
            session_date=parse_date(meta_data.dateOfExperiment + ' ' +
                                    meta_data.timeOfExperiment))
        experiment.Session.insert1(session_key, ignore_extra_fields=True)

        print(
            f'\tInsert Session - {session_key["subject_id"]} - {session_key["session_date"]}'
        )

        # ==================== Probe Insertion ====================
        brain_location_key = (experiment.BrainLocation & dict(
            brain_area=meta_data.extracellular.recordingLocation,
            hemisphere=hemi,
            skull_reference=skull_reference)).fetch1('KEY')
        insertion_loc_key = dict(
            brain_location_key,
            ml_location=meta_data.extracellular.recordingCoordinates[0] *
            1000,  # mm to um
            ap_location=meta_data.extracellular.recordingCoordinates[1] *
            1000,  # mm to um
            dv_location=meta_data.extracellular.recordingCoordinates[2]
        )  # already in um

        # Fields shared by the probe insertion and all of its part tables.
        insertion_key = dict(session_key,
                             insertion_number=1,
                             probe=probe,
                             electrode_config_name=electrode_config_name)

        with ephys.ProbeInsertion.connection.transaction:
            ephys.ProbeInsertion.insert1(insertion_key,
                                         ignore_extra_fields=True)
            ephys.ProbeInsertion.InsertionLocation.insert1(
                dict(insertion_key, **insertion_loc_key),
                ignore_extra_fields=True)
            ephys.ProbeInsertion.ElectrodeSitePosition.insert(
                (dict(insertion_key,
                      electrode_group=0,
                      electrode=site_idx + 1,
                      electrode_posx=x * 1000,
                      electrode_posy=y * 1000,
                      electrode_posz=z * 1000)
                 for site_idx, (x, y, z) in enumerate(
                     meta_data.extracellular.siteLocations)),
                ignore_extra_fields=True)

        print(
            f'\tInsert ProbeInsertion - Location: {brain_location_key["brain_location_name"]}'
        )

        # ==================== Virus ====================
        if 'virus' in meta_data._fieldnames and isinstance(
                meta_data.virus, sio.matlab.mio5_params.mat_struct):
            virus_info = dict(
                virus_source=meta_data.virus.virusSource,
                virus=meta_data.virus.virusID,
                virus_lot_number=meta_data.virus.virusLotNumber
                if len(meta_data.virus.virusLotNumber) != 0 else '',
                virus_titer=meta_data.virus.virusTiter.replace('x10', '')
                if meta_data.virus.virusTiter != 'untitered' else None)
            virus.Virus.insert1(virus_info, skip_duplicates=True)

            # -- BrainLocation
            brain_location_key = (experiment.BrainLocation & {
                'brain_area': meta_data.virus.infectionLocation,
                'hemisphere': hemi,
                'skull_reference': skull_reference
            }).fetch1('KEY')
            virus_injection = dict(
                {
                    **virus_info,
                    **subject_key,
                    **brain_location_key
                },
                injection_date=parse_date(meta_data.virus.injectionDate))

            # One row per (coordinate, volume) pair.
            virus_injections = [
                dict(virus_injection,
                     injection_id=inj_idx + 1,
                     ml_location=coord[0] * 1000,
                     ap_location=coord[1] * 1000,
                     dv_location=coord[2] * 1000,
                     injection_volume=vol)
                for inj_idx, (coord, vol) in enumerate(
                    zip(meta_data.virus.infectionCoordinates,
                        meta_data.virus.injectionVolume))
            ]
            virus.VirusInjection.insert(virus_injections,
                                        ignore_extra_fields=True,
                                        skip_duplicates=True)
            # Report the number of rows actually built for insertion.
            print(
                f'\tInsert Virus Injections - Count: {len(virus_injections)}'
            )

        # ==================== Photostim ====================
        if 'photostim' in meta_data._fieldnames and isinstance(
                meta_data.photostim, sio.matlab.mio5_params.mat_struct):
            # Normalise a single location / coordinate triple to arrays so the
            # boolean-mask selection below works for one or many sites.
            photostimLocation = (
                meta_data.photostim.photostimLocation if isinstance(
                    meta_data.photostim.photostimLocation, np.ndarray) else
                np.array([meta_data.photostim.photostimLocation]))
            photostimCoordinates = (
                meta_data.photostim.photostimCoordinates if isinstance(
                    meta_data.photostim.photostimCoordinates[0], np.ndarray)
                else np.array([meta_data.photostim.photostimCoordinates]))
            photostim_locs = []
            for ba in set(photostimLocation):
                coords = photostimCoordinates[photostimLocation == ba]
                # NOTE(review): hemisphere is decided by the sign of coord[1],
                # yet coord[1] is stored as ap_location below - confirm the
                # axis order of photostimCoordinates.
                for coord in coords:
                    photostim_locs.append(
                        (ba, 'left' if coord[1] < 0 else 'right', coord))
                # Several sites in one area additionally yield a 'both' entry
                # built from the first site's coordinates.
                if len(coords) > 1:
                    photostim_locs.append(
                        (ba, 'both',
                         np.array(
                             [coords[0][0],
                              abs(coords[0][1]), coords[0][2]])))

            for stim_idx, (loc, hem, coord) in enumerate(photostim_locs):
                brain_location_key = (experiment.BrainLocation & dict(
                    brain_area=loc,
                    hemisphere=hem,
                    skull_reference=skull_reference)).fetch1('KEY')
                experiment.Photostim.insert1(dict(
                    session_key,
                    **brain_location_key,
                    photo_stim=stim_idx + 1,
                    photostim_device=photostim_devices[
                        meta_data.photostim.photostimWavelength],
                    ml_location=coord[0] * 1000,
                    ap_location=coord[1] * 1000,
                    dv_location=coord[2] * 1000),
                                             ignore_extra_fields=True)

            # Count the rows inserted; len() of the raw MATLAB field is the
            # string length when a single location was squeezed to a scalar.
            print(
                f'\tInsert Photostim - Count: {len(photostim_locs)}'
            )
Ejemplo n.º 3
0
def main(data_dir='/data/data_structure'):
    """Ingest imaging data from every ``*.mat`` file found in ``data_dir``.

    Each file is matched to exactly one existing ``experiment.Session`` using
    the subject nickname, session date, imaging depth and field of view parsed
    from the file name. Trials, trial events, the scan, its ROIs and
    go-cue-aligned trial traces are then inserted. A file is skipped when
    ``imaging.TrialTrace`` already holds rows for the matched session.

    Args:
        data_dir: Directory containing the ``*.mat`` files, each holding an
            ``obj`` struct.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
        Exception: If more than one session matches a file.
    """
    data_dir = pathlib.Path(data_dir)
    if not data_dir.exists():
        raise FileNotFoundError(f'Path not found!! {data_dir.as_posix()}')

    # ==================== DEFINE CONSTANTS =====================

    # Labels for the first six rows of trialTypeMat, in row order.
    trial_type_str = ['HitL', 'HitR', 'ErrL', 'ErrR', 'NoLickL', 'NoLickR']
    # trial type label -> (outcome, trial_instruction)
    trial_type_mapper = {
        'HitR': ('hit', 'right'),
        'HitL': ('hit', 'left'),
        'ErrR': ('miss', 'right'),
        'ErrL': ('miss', 'left'),
        'NoLickR': ('ignore', 'right'),
        'NoLickL': ('ignore', 'left')
    }

    cell_type_mapper = {
        'p': 'PT',
        'i': 'IT',
        '': 'N/A',
        'in': 'interneuron',
        'pn': 'Pyr'
    }

    post_resp_tlim = 2  # a trial may last at most 2 seconds after response cue

    task_protocol = {'task': 'audio delay', 'task_protocol': 1}

    insert_kwargs = {
        'ignore_extra_fields': True,
        'allow_direct_insert': True,
        'skip_duplicates': True
    }

    # ================== INGESTION OF DATA ==================
    data_files = data_dir.glob('*.mat')

    for data_file in data_files:
        print(f'-- Read {data_file} --')

        # ---- parse the session identity out of the file name ----
        # Raw strings keep ``\d`` a regex class instead of a string escape.
        fname = data_file.stem
        subject_nickname = re.search(r'data_(an\d+)_', fname).group(1)
        session_date = parse_date(
            re.search(subject_nickname + r'_(\d+_\d+_\d+_)',
                      fname).group(1).replace('_', ''))
        depth = int(re.search(r'_(\d+)(_fv|$)', fname).group(1))
        # Field of view defaults to 1 when the name carries no "fv<N>" part.
        fov_match = re.search(r'fv(\d+)', fname)
        fov = int(fov_match.group(1)) if fov_match else 1

        sessions = (experiment.Session & (lab.Subject & {
            'subject_nickname': subject_nickname
        }).proj() & {
            'session_date': session_date,
            'fov': fov
        } & (experiment.Session.ImagingDepth & {
            'imaging_depth': depth
        }))
        if len(sessions) >= 2:
            # f-string so the offending file name actually appears.
            raise Exception(f'Multiple sessions found for {fname}')
        session_key = sessions.fetch1('KEY')

        print(f'\tMatched: {session_key}')

        if imaging.TrialTrace & session_key:
            print('Data ingested, skipping over...')
            continue

        sess_data = sio.loadmat(data_file,
                                struct_as_record=False,
                                squeeze_me=True)['obj']

        # ---- trial data ----
        trial_zip = zip(sess_data.trialIds, sess_data.trialStartTimes,
                        sess_data.trialTypeMat[:6, :].T,
                        sess_data.trialPropertiesHash.value[0],
                        sess_data.trialPropertiesHash.value[1],
                        sess_data.trialPropertiesHash.value[2])

        print('---- Ingesting trial data ----')
        session_trials, behavior_trials, trial_events = [], [], []

        for (tr_id, tr_start, trial_type_mtx, sample_start, delay_start,
             response_start) in tqdm(trial_zip):

            # A trial without a response cue has a NaN response time; count it
            # as 0 so stop_time stays a finite number instead of NaN.
            response_offset = 0 if np.isnan(response_start) else response_start
            tkey = dict(session_key,
                        trial=tr_id,
                        start_time=Decimal(tr_start),
                        stop_time=Decimal(tr_start + response_offset +
                                          post_resp_tlim))
            session_trials.append(tkey)

            # Exactly one flagged row identifies the trial type; anything else
            # is recorded as non-performing.
            trial_type = np.array(trial_type_str)[trial_type_mtx.astype(bool)]
            if len(trial_type) == 1:
                outcome, trial_instruction = trial_type_mapper[trial_type[0]]
            else:
                outcome, trial_instruction = 'non-performing', 'non-performing'

            bkey = dict(tkey,
                        **task_protocol,
                        trial_instruction=trial_instruction,
                        outcome=outcome,
                        early_lick='no early')
            behavior_trials.append(bkey)

            for etype, etime in zip(
                ('sample', 'delay', 'go'),
                (sample_start, delay_start, response_start)):
                if not np.isnan(etime):
                    # Event ids run continuously across the whole session.
                    trial_events.append(
                        dict(tkey,
                             trial_event_id=len(trial_events) + 1,
                             trial_event_type=etype,
                             trial_event_time=etime))

        # insert trial info
        experiment.SessionTrial.insert(session_trials, **insert_kwargs)
        experiment.BehaviorTrial.insert(behavior_trials, **insert_kwargs)
        experiment.TrialEvent.insert(trial_events, **insert_kwargs)

        # ---- Scan info ----
        # timeSeriesArrayHash.value[0] feeds the ROI traces and value[1] the
        # neuropil traces below.
        roi_series = sess_data.timeSeriesArrayHash.value[0]
        neuropil_series = sess_data.timeSeriesArrayHash.value[1]

        scan = dict(**session_key,
                    image_gcamp=sess_data.images.value[0],
                    image_ctb=sess_data.images.value[1],
                    image_beads=sess_data.images.value[2]
                    if len(sess_data.images.value[2]) else np.nan,
                    frame_time=roi_series.time)

        # Empty cell-type entries are normalised to '' for the mapper lookup.
        cell_type_vec = [ct if len(ct) else '' for ct in roi_series.cellType]

        rois = [
            dict(**scan,
                 roi_idx=idx,
                 cell_type=cell_type_mapper[cell_type],
                 roi_trace=roi_trace,
                 neuropil_trace=neuropil_trace,
                 roi_pixel_list=roi_plist,
                 neuropil_pixel_list=neuropil_plist,
                 inc=bool(np.mean(roi_trace) / np.mean(neuropil_trace) > 1.05))
            for (idx, cell_type, roi_trace, neuropil_trace, roi_plist,
                 neuropil_plist) in zip(
                     roi_series.ids, cell_type_vec,
                     roi_series.valueMatrix,
                     neuropil_series.valueMatrix,
                     roi_series.pixel_list,
                     neuropil_series.pixel_list)
        ]

        imaging.Scan.insert1(scan, **insert_kwargs)
        imaging.Scan.Roi.insert(rois, **insert_kwargs)

        # trial -> (trial start time, go-cue time relative to trial start)
        tr_events = {
            tr: (float(stime), float(gotime))
            for tr, stime, gotime in zip(
                *(experiment.SessionTrial * experiment.TrialEvent
                  & session_key & 'trial_event_type = "go"'
                  ).fetch('trial', 'start_time', 'trial_event_time'))
        }

        print('---- Ingesting trial trace ----')
        trials = roi_series.trial
        frame_time = roi_series.time

        # The go-cue frame of a trial does not depend on the ROI, so the
        # windows are located once here instead of once per ROI.
        # NOTE(review): when go_id < 45 the negative start wraps around and
        # the window is not what is intended - confirm whether such trials
        # should be skipped.
        trial_windows = []
        for tr in set(trials):
            if tr in tr_events:
                go_cue_time = sum(tr_events[tr])
                go_id = np.abs(frame_time - go_cue_time).argmin()
                trial_windows.append(
                    (tr, go_cue_time,
                     slice(go_id - 45, go_id + 45),   # 45 frames either side
                     slice(go_id - 45, go_id - 39)))  # first 6 frames: baseline

        trial_traces = []
        for roi_id, trace in tqdm(zip(roi_series.ids, roi_series.valueMatrix)):
            for tr, go_cue_time, window, baseline_window in trial_windows:
                baseline = np.mean(trace[baseline_window])
                trial_traces.append(
                    dict(**scan,
                         roi_idx=roi_id,
                         trial=tr,
                         original_time=frame_time[window],
                         aligned_time=frame_time[window] - go_cue_time,
                         aligned_trace=trace[window],
                         dff=(trace[window] - baseline) / baseline))

        imaging.TrialTrace.insert(trial_traces, **insert_kwargs)
Ejemplo n.º 4
0
def main(data_dir='./data/data_structure'):
    data_dir = pathlib.Path(data_dir)
    if not data_dir.exists():
        raise FileNotFoundError(f'Path not found!! {data_dir.as_posix()}')

    # ==================== DEFINE CONSTANTS =====================

    session_suffixes = ['a', 'b', 'c', 'd', 'e']

    trial_type_str = ['HitR', 'HitL', 'ErrR', 'ErrL', 'NoLickR', 'NoLickL']
    trial_type_mapper = {
        'HitR': ('hit', 'right'),
        'HitL': ('hit', 'left'),
        'ErrR': ('miss', 'right'),
        'ErrL': ('miss', 'left'),
        'NoLickR': ('ignore', 'right'),
        'NoLickL': ('ignore', 'left')
    }

    photostim_mapper = {
        1: {
            'brain_area': 'alm',
            'hemi': 'left',
            'duration': 0.5,
            'spot': 1,
            'pre_go_end_time': 1.6,
            'period': 'sample'
        },
        2: {
            'brain_area': 'alm',
            'hemi': 'left',
            'duration': 0.5,
            'spot': 1,
            'pre_go_end_time': 0.8,
            'period': 'early_delay'
        },
        3: {
            'brain_area': 'alm',
            'hemi': 'left',
            'duration': 0.5,
            'spot': 1,
            'pre_go_end_time': 0.3,
            'period': 'middle_delay'
        },
        4: {
            'brain_area': 'alm',
            'hemi': 'left',
            'duration': 0.8,
            'spot': 1,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        },
        5: {
            'brain_area': 'alm',
            'hemi': 'right',
            'duration': 0.8,
            'spot': 1,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        },
        6: {
            'brain_area': 'alm',
            'hemi': 'both',
            'duration': 0.8,
            'spot': 4,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        },
        7: {
            'brain_area': 'alm',
            'hemi': 'both',
            'duration': 0.8,
            'spot': 1,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        },
        8: {
            'brain_area': 'alm',
            'hemi': 'left',
            'duration': 0.8,
            'spot': 4,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        },
        9: {
            'brain_area': 'alm',
            'hemi': 'right',
            'duration': 0.8,
            'spot': 4,
            'pre_go_end_time': 0.9,
            'period': 'early_delay'
        }
    }

    cell_type_mapper = {'pyramidal': 'Pyr', 'FS': 'FS', 'IT': 'IT', 'PT': 'PT'}

    post_resp_tlim = 2  # a trial may last at most 2 seconds after response cue

    task_protocol = {'task': 'audio delay', 'task_protocol': 1}

    clustering_method = 'manual'

    insert_kwargs = {
        'ignore_extra_fields': True,
        'allow_direct_insert': True,
        'skip_duplicates': True
    }

    # ================== INGESTION OF DATA ==================
    data_files = data_dir.glob('*.mat')

    for data_file in data_files:
        print(f'-- Read {data_file} --')

        fname = data_file.stem
        subject_id = int(re.search('ANM\d+', fname).group().replace('ANM', ''))
        session_date = parse_date(
            re.search('_\d+', fname).group().replace('_', ''))

        sessions = (experiment.Session & {
            'subject_id': subject_id,
            'session_date': session_date
        })
        if len(sessions) < 2:
            session_key = sessions.fetch1('KEY')
        else:
            if fname[-1] in session_suffixes:
                sess_num = sessions.fetch('session', order_by='session')
                session_letter_mapper = {
                    letter: s_no
                    for letter, s_no in zip(session_suffixes, sess_num)
                }
                session_key = (sessions & {
                    'session': session_letter_mapper[fname[-1]]
                }).fetch1('KEY')
            else:
                raise Exception('Multiple sessions found for {fname}')

        print(f'\tMatched: {session_key}')

        if ephys.TrialSpikes & session_key:
            print('Data ingested, skipping over...')
            continue

        sess_data = sio.loadmat(data_file,
                                struct_as_record=False,
                                squeeze_me=True)['obj']

        # get time conversion factor - (-1) to take into account Matlab's 1-based indexing
        ts_time_conversion = time_unit_conversion_factor[
            sess_data.timeUnitNames[
                sess_data.timeSeriesArrayHash.value.timeUnit - 1]]
        trial_time_conversion = time_unit_conversion_factor[
            sess_data.timeUnitNames[sess_data.trialTimeUnit - 1]]
        unit_time_converstion = time_unit_conversion_factor[
            sess_data.timeUnitNames[sess_data.eventSeriesHash.value[0].timeUnit
                                    - 1]]

        # ---- time-series data ----
        ts_tvec = sess_data.timeSeriesArrayHash.value.time * ts_time_conversion
        ts_trial = sess_data.timeSeriesArrayHash.value.trial
        lick_trace = sess_data.timeSeriesArrayHash.value.valueMatrix[:, 0]
        aom_input_trace = sess_data.timeSeriesArrayHash.value.valueMatrix[:, 1]
        laser_power = sess_data.timeSeriesArrayHash.value.valueMatrix[:, 2]

        # ---- trial data ----
        photostims = (experiment.Photostim * experiment.BrainLocation
                      & session_key)

        trial_zip = zip(
            sess_data.trialIds,
            sess_data.trialStartTimes * trial_time_conversion,
            sess_data.trialTypeMat[:6, :].T, sess_data.trialTypeMat[6, :].T,
            sess_data.trialPropertiesHash.value[0] * trial_time_conversion,
            sess_data.trialPropertiesHash.value[1] * trial_time_conversion,
            sess_data.trialPropertiesHash.value[2] * trial_time_conversion,
            sess_data.trialPropertiesHash.value[-1])

        print('---- Ingesting trial data ----')
        (session_trials, behavior_trials, trial_events, photostim_trials,
         photostim_events, photostim_traces,
         lick_traces) = [], [], [], [], [], [], []

        for (tr_id, tr_start, trial_type_mtx, is_early_lick, sample_start,
             delay_start, response_start, photostim_type) in tqdm(trial_zip):

            tkey = dict(
                session_key,
                trial=tr_id,
                start_time=Decimal(tr_start),
                stop_time=Decimal(tr_start + (
                    0 if np.isnan(response_start) else response_start) +
                                  post_resp_tlim))
            session_trials.append(tkey)

            trial_type = np.array(trial_type_str)[trial_type_mtx.astype(bool)]
            if len(trial_type) == 1:
                outcome, trial_instruction = trial_type_mapper[trial_type[0]]
            else:
                outcome, trial_instruction = 'non-performing', 'non-performing'

            bkey = dict(tkey,
                        **task_protocol,
                        trial_instruction=trial_instruction,
                        outcome=outcome,
                        early_lick='early' if is_early_lick else 'no early')
            behavior_trials.append(bkey)

            lick_traces.append(
                dict(bkey,
                     lick_trace=lick_trace[ts_trial == tr_id],
                     lick_trace_timestamps=ts_tvec[ts_trial == tr_id] -
                     tr_start))

            for etype, etime in zip(
                ('sample', 'delay', 'go'),
                (sample_start, delay_start, response_start)):
                if not np.isnan(etime):
                    trial_events.append(
                        dict(tkey,
                             trial_event_id=len(trial_events) + 1,
                             trial_event_type=etype,
                             trial_event_time=etime))

            if photostims and photostim_type != 0:
                pkey = dict(tkey)
                photostim_trials.append(pkey)
                photostim_type = photostim_type.astype(int)
                if photostim_type in photostim_mapper:
                    photstim_detail = photostim_mapper[photostim_type]
                    photostim_key = (
                        photostims & {
                            'brain_area': photstim_detail['brain_area'],
                            'hemisphere': photstim_detail['hemi']
                        })
                    if photostim_key:
                        photostim_key = photostim_key.fetch1('KEY')
                        stim_power = laser_power[ts_trial == tr_id]
                        stim_power = np.where(
                            np.isinf(stim_power), 0,
                            stim_power)  # handle cases where stim power is Inf
                        photostim_events.append(
                            dict(pkey,
                                 **photostim_key,
                                 photostim_event_id=len(photostim_events) + 1,
                                 power=stim_power.max()
                                 if len(stim_power) > 0 else None,
                                 duration=Decimal(photstim_detail['duration']),
                                 photostim_event_time=response_start -
                                 photstim_detail['pre_go_end_time'] -
                                 photstim_detail['duration'],
                                 stim_spot_count=photstim_detail['spot'],
                                 photostim_period=photstim_detail['period']))
                        photostim_traces.append(
                            dict(
                                pkey,
                                aom_input_trace=aom_input_trace[ts_trial ==
                                                                tr_id],
                                laser_power=laser_power[ts_trial == tr_id],
                                photostim_timestamps=ts_tvec[ts_trial == tr_id]
                                - tr_start))

        # insert trial info
        experiment.SessionTrial.insert(session_trials, **insert_kwargs)
        experiment.BehaviorTrial.insert(behavior_trials, **insert_kwargs)
        experiment.PhotostimTrial.insert(photostim_trials, **insert_kwargs)
        experiment.TrialEvent.insert(trial_events, **insert_kwargs)
        experiment.PhotostimEvent.insert(photostim_events, **insert_kwargs)
        experiment.PhotostimTrace.insert(photostim_traces, **insert_kwargs)
        tracking.LickTrace.insert(lick_traces, **insert_kwargs)

        # ---- units ----
        insert_key = (ephys.ProbeInsertion & session_key).fetch1()
        ap, dv = (ephys.ProbeInsertion.InsertionLocation & session_key).fetch1(
            'ap_location', 'dv_location')
        e_sites = {
            e: (y - ap, z - dv)
            for e, y, z in zip(
                *(ephys.ProbeInsertion.ElectrodeSitePosition & session_key
                  ).fetch('electrode', 'electrode_posy', 'electrode_posz'))
        }
        tr_events = {
            tr: (float(stime), float(gotime))
            for tr, stime, gotime in zip(
                *(experiment.SessionTrial * experiment.TrialEvent
                  & session_key & 'trial_event_type = "go"'
                  ).fetch('trial', 'start_time', 'trial_event_time'))
        }

        print('---- Ingesting spike data ----')
        unit_spikes, unit_cell_types, trial_spikes = [], [], []
        for u_name, u_value in tqdm(
                zip(sess_data.eventSeriesHash.keyNames,
                    sess_data.eventSeriesHash.value)):
            unit = int(re.search('\d+', u_name).group())
            electrode = np.unique(u_value.channel)[0]
            spike_times = u_value.eventTimes * unit_time_converstion

            unit_key = dict(insert_key,
                            clustering_method=clustering_method,
                            unit=unit)
            unit_spikes.append(
                dict(unit_key,
                     electrode_group=0,
                     unit_quality='good',
                     electrode=electrode,
                     unit_posx=e_sites[electrode][0],
                     unit_posy=e_sites[electrode][1],
                     spike_times=spike_times,
                     waveform=u_value.waveforms))
            unit_cell_types += [
                dict(unit_key,
                     cell_type=(cell_type_mapper[cell_type]
                                if len(cell_type) > 0 else 'N/A'))
                for cell_type in (
                    u_value.cellType if isinstance(u_value.cellType, (
                        list, np.ndarray)) else [u_value.cellType])
            ]
            # get trial's spike times, shift by start-time, then by go-time -> align to go-time
            trial_spikes += [
                dict(unit_key,
                     trial=tr,
                     spike_times=(spike_times[u_value.eventTrials == tr] -
                                  tr_events[tr][0] - tr_events[tr][1]))
                for tr in set(u_value.eventTrials) if tr in tr_events
            ]

        ephys.Unit.insert(unit_spikes, **insert_kwargs)
        ephys.UnitCellType.insert(unit_cell_types, **insert_kwargs)
        ephys.TrialSpikes.insert(trial_spikes, **insert_kwargs)
Ejemplo n.º 5
0
def main(meta_data_dir='/data/meta_data'):
    meta_data_dir = pathlib.Path(meta_data_dir)
    if not meta_data_dir.exists():
        raise FileNotFoundError(f'Path not found!! {meta_data_dir.as_posix()}')

    # ==================== DEFINE CONSTANTS =====================

    # ---- inferred from paper ----
    hemi = 'left'
    skull_reference = 'bregma'
    photostim_devices = {473: 'LaserGem473', 594: 'LaserCoboltMambo100',  596: 'LaserCoboltMambo100'}

    # ---- from lookup ----
    probe = 'A4x8-5mm-100-200-177'
    electrode_config_name = 'silicon32'

    # ---- virus source mapper ----
    virus_source_mapper = {
        'Upenn core': 'UPenn',
        'Upenn Core': 'UPenn',
        'UPenn': 'UPenn'
    }

    # ---- brain location mapper ----
    brain_location_mapper = {
        'Left ALM': 'left_alm',
        'Left pontine nucleus': 'left_pons'
    }

    # ================== INGESTION OF METADATA ==================

    # ---- delete all Sessions ----
    experiment.Session.delete()

    # ---- insert metadata ----
    meta_data_files = meta_data_dir.glob('*.mat')
    for meta_data_file in tqdm(meta_data_files):
        print(f'-- Read {meta_data_file} --')
        meta_data = sio.loadmat(meta_data_file, struct_as_record = False, squeeze_me=True)['meta_data']

        # ==================== person ====================
        person_key = dict(username=meta_data.experimenters,
                          fullname=meta_data.experimenters)
        lab.Person.insert1(person_key, skip_duplicates=True)

        # ==================== subject gene modification ====================
        modified_genes = (meta_data.animalGeneModification
                          if isinstance(meta_data.animalGeneModification, (np.ndarray, list))
                          else [meta_data.animalGeneModification])
        lab.ModifiedGene.insert((dict(gene_modification=g, gene_modification_description=g)
                                 for g in modified_genes), skip_duplicates=True)

        # ==================== subject strain ====================
        animal_strains = (meta_data.animalStrain
                          if isinstance(meta_data.animalStrain, (np.ndarray, list))
                          else [meta_data.animalStrain])
        lab.AnimalStrain.insert(zip(animal_strains), skip_duplicates=True)

        # ==================== subject ====================
        animal_id = (meta_data.animalID[0]
                     if isinstance(meta_data.animalID, (np.ndarray, list)) else meta_data.animalID)
        animal_source = (meta_data.animalSource[0]
                         if isinstance(meta_data.animalSource, (np.ndarray, list)) else meta_data.animalSource)
        subject_key = dict(subject_id=int(re.search('\d+', animal_id).group()),
                           subject_nickname=re.search('meta_(an\d+)_', meta_data_file.stem).group(1),
                           sex=meta_data.sex[0].upper() if len(meta_data.sex) != 0 else 'U',
                           species=meta_data.species,
                           animal_source=animal_source)
        try:
            date_of_birth = parse_date(meta_data.dateOfBirth)
            subject_key['date_of_birth'] = date_of_birth
        except:
            pass

        lab.AnimalSource.insert1((animal_source,), skip_duplicates=True)

        with lab.Subject.connection.transaction:
            if subject_key not in lab.Subject.proj():
                lab.Subject.insert1(subject_key)
                lab.Subject.GeneModification.insert((dict(subject_key, gene_modification=g) for g in modified_genes),
                                                     ignore_extra_fields=True)
                lab.Subject.Strain.insert((dict(subject_key, animal_strain=strain) for strain in animal_strains),
                                                     ignore_extra_fields=True)

        # ==================== session ====================
        session_key = dict(subject_key, username=person_key['username'],
                           session=len(experiment.Session & subject_key) + 1,
                           session_date=parse_date(meta_data.dateOfExperiment))
        filename = meta_data_file.stem
        if re.search('fv(\d+)', filename):
            session_key.update(fov=int(re.search('fv(\d+)', filename).group(1)))

        experiment.Session.insert1(session_key, ignore_extra_fields=True)

        experiment.Session.ImagingDepth.insert1(
            dict(session_key, imaging_depth=int(re.search('_(\d+)(_fv|$)', filename).group(1))),
            ignore_extra_fields=True)

        print(f'\tInsert Session - {session_key["subject_id"]} - {session_key["session_date"]}')


        # ==================== Virus ====================
        if 'virus' in meta_data._fieldnames and isinstance(meta_data.virus, sio.matlab.mio5_params.mat_struct):
            virus_info = dict(
                virus_source=virus_source_mapper[meta_data.virus.Source],
                virus=meta_data.virus.ID[1],
                virus_titer=meta_data.virus.Titer.replace('x10', '') if meta_data.virus.Titer != 'unknown' else None)

            virus.Virus.insert1(virus_info, skip_duplicates=True)

            # -- BrainLocation
            brain_location_key = (experiment.BrainLocation & {'brain_location_name': brain_location_mapper[meta_data.virus.Location],
                                                              'hemisphere': hemi,
                                                              'skull_reference': skull_reference}).fetch1('KEY')
            virus_injection = dict(
                {**virus_info, **subject_key, **brain_location_key},
                injection_date=parse_date(meta_data.virus.injectionDate))

            virus.VirusInjection.insert([dict(virus_injection,
                                              injection_id=inj_idx + 1,
                                              virus_dilution=float(re.search('1:(\d+) dilution',
                                                                   meta_data.virus.Concentration).group(1)) \
                                                                       if 'Concentraion' in meta_data.virus._fieldnames else None,
                                              ml_location=coord[0] * 1000,
                                              ap_location=coord[1] * 1000,
                                              dv_location=coord[2] * 1000,
                                              injection_volume=vol)
                                         for inj_idx, (coord, vol) in enumerate(zip(meta_data.virus.Coordinates,
                                                                                    meta_data.virus.injectionVolume))],
                                        ignore_extra_fields=True, skip_duplicates=True)
            print(f'\tInsert Virus Injections - Count: {len(meta_data.virus.injectionVolume)}')