예제 #1
0
def export_to_nwb(session_key, nwb_output_dir=default_nwb_output_dir, save=False, overwrite=True):
    """
    Export one session to an NWB 2.0 file.

    The file contains session/subject metadata, virus injections, one electrode
    group (with its electrodes and sorted units) per probe insertion,
    photostimulation sites and the behavior trial table.

    Parameters
    ----------
    session_key : dict
        DataJoint restriction selecting exactly one ``acquisition.Session``.
    nwb_output_dir : str
        Directory the file is written to when ``save`` is True (created if missing).
    save : bool
        If True, write ``<nwb_output_dir>/<identifier>.nwb``.
    overwrite : bool
        If False and the target file already exists, nothing is written.

    Returns
    -------
    NWBFile
        The in-memory NWB file, whether or not it was written to disk.
    """
    this_session = (acquisition.Session & session_key).fetch1()

    # ===============================================================================
    # ============================== META INFORMATION ===============================
    # ===============================================================================

    # -- NWB file - a NWB2.0 file for each session
    experimenter = (acquisition.Session.Experimenter & session_key).fetch1('experimenter')
    nwbfile = NWBFile(
        # str(): '_'.join() needs strings and the subject is also str()-ed for the
        # Subject below, so it is not assumed to be stored as a string.
        # NOTE(review): the identifier contains ' ' and ':' and is reused as the
        # output file name, which is not a valid file name on Windows.
        identifier='_'.join([str(this_session['subject']),
                             this_session['session_time'].strftime('%Y-%m-%d %H:%M:%S')]),
        related_publications='https://www.nature.com/articles/s41586-018-0633-x',
        experiment_description='Extracellular recording in ALM',
        session_description='',
        session_start_time=this_session['session_time'],
        file_create_date=datetime.now(tzlocal()),
        experimenter=experimenter,
        institution=institution,
        keywords=['motor planning', 'anterior lateral cortex',
                  'ALM', 'Extracellular recording', 'optogenetics'])

    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    zygosity = subject.Zygosity & subj
    nwbfile.subject = pynwb.file.Subject(
        subject_id=str(this_session['subject']),
        genotype=' x '.join(zygosity.fetch('allele')) if len(zygosity) else 'unknown',
        sex=subj['sex'],
        species=subj['species'],
        date_of_birth=(datetime.combine(subj['date_of_birth'], zero_zero_time)
                       if subj['date_of_birth'] else None))

    # -- virus: every injection (joined with its virus) serialized as JSON; fields
    # already present on the subject record are left out
    nwbfile.virus = json.dumps([{k: str(v) for k, v in virus_injection.items() if k not in subj}
                                for virus_injection in action.VirusInjection * reference.Virus & session_key])

    # ===============================================================================
    # ======================== EXTRACELLULAR & CLUSTERING ===========================
    # ===============================================================================

    # In the event of multiple probe recording (i.e. multiple probe insertions), the
    # clustering results (and the associated units) are associated with the
    # corresponding probe. Each probe insertion is associated with one
    # ElectrodeConfiguration (which may define multiple electrode groups).

    dj_insert_location = ephys.ProbeInsertion
    n_electrode_rows = 0  # running number of rows in the file-wide electrodes table

    for probe_idx, probe_insertion in enumerate(ephys.ProbeInsertion & session_key):
        probe = (reference.Probe & probe_insertion).fetch1()
        # Restrict by this insertion's primary key, not the whole session:
        # fetch1() on the session fails as soon as it has two insertions.
        insertion = (dj_insert_location
                     & {k: probe_insertion[k] for k in dj_insert_location.primary_key}).fetch1()

        # Insertions of the same probe type share one device; groups are numbered
        # per insertion (the first keeps the '<probe_type>_g1' name).
        probe_type = probe['probe_type']
        probe_device = (nwbfile.get_device(probe_type) if probe_type in nwbfile.devices
                        else nwbfile.create_device(name=probe_type))
        electrode_group = nwbfile.create_electrode_group(
            name=f'{probe_type}_g{probe_idx + 1}',
            description='N/A',
            device=probe_device,
            location=json.dumps({k: str(v) for k, v in insertion.items()
                                 if k not in dj_insert_location.primary_key}))

        # channel_id -> row index in the electrodes table, so units of this probe
        # reference its own electrodes even when earlier probes already added rows.
        electrode_row_of_channel = {}
        for chn in (reference.Probe.Channel & probe).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['channel_id'] - 1,
                group=electrode_group,
                filtering=hardware_filter,
                imp=-1.,
                x=np.nan,
                y=np.nan,
                z=np.nan,
                location=insertion['brain_location'])
            electrode_row_of_channel[chn['channel_id']] = n_electrode_rows
            n_electrode_rows += 1

        # --- unit spike times ---
        if probe_idx == 0:
            # Unit-table columns are file-wide; declaring them again for a later
            # probe would fail because the columns already exist.
            nwbfile.add_unit_column(name='cell_type', description='cell type (e.g. fast spiking or pyramidal)')
            nwbfile.add_unit_column(name='sampling_rate', description='sampling rate of the waveform, Hz')

        units = (ephys.UnitSpikeTimes & probe_insertion).fetch(as_dict=True)
        if not units:
            continue  # no sorted units on this probe (np.hstack cannot stack nothing)

        # One observation interval spanning all spikes of this probe, widened by 0.001
        # on each side (1 ms, presumably, if spike times are in seconds).
        spk_times_all = np.hstack([u['spike_times'] for u in units])
        obs_intervals = np.array([[np.min(spk_times_all) - 0.001, np.max(spk_times_all) + 0.001]])

        for unit in units:
            nwbfile.add_unit(id=unit['unit_id'],
                             electrodes=[electrode_row_of_channel[unit['channel']]],
                             electrode_group=electrode_group,
                             cell_type=unit['unit_cell_type'],
                             spike_times=unit['spike_times'],
                             obs_intervals=obs_intervals,
                             waveform_mean=np.mean(unit['spike_waveform'], axis=0),
                             waveform_sd=np.std(unit['spike_waveform'], axis=0),
                             sampling_rate=20000)

    # ===============================================================================
    # ============================= PHOTO-STIMULATION ===============================
    # ===============================================================================
    photostim_excluded = acquisition.PhotoStim.primary_key + ['photo_stim_method', 'photo_stim_wavelength']
    for photostim in acquisition.PhotoStim * reference.BrainLocation & session_key:

        # one device per stimulation method, reused across sites
        stim_device = (nwbfile.get_device(photostim['photo_stim_method'])
                       if photostim['photo_stim_method'] in nwbfile.devices
                       else nwbfile.create_device(name=photostim['photo_stim_method']))

        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name=photostim['brain_location'],
            device=stim_device,
            excitation_lambda=float(photostim['photo_stim_wavelength']),
            location=json.dumps({k: str(v) for k, v in photostim.items()
                                 if k in acquisition.PhotoStim.heading.names and k not in photostim_excluded}),
            description='')
        nwbfile.add_ogen_site(stim_site)

    # ===============================================================================
    # =============================== BEHAVIOR TRIALS ===============================
    # ===============================================================================

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with two mandatory attributes:
    # 'start_time' and 'stop_time'. Other trial-related information needs to be added
    # to the trial-table as additional columns (with column name and column description).

    dj_trial = acquisition.Session * behavior.TrialSet.Trial
    skip_adding_columns = acquisition.Session.primary_key + \
        ['trial_id', 'trial_start_idx', 'trial_end_idx', 'trial_start_time', 'session_note']

    if behavior.TrialSet.Trial & session_key:
        # Column descriptions: the text following each attribute name in the heading,
        # with whitespace (and whitespace before ':') collapsed to single spaces.
        trial_heading = str(dj_trial.heading)
        trial_columns = [{'name': tag.replace('trial_', ''),
                          'description': re.sub(r'\s+:|\s+', ' ', re.search(
                              f'(?<={tag})(.*)', trial_heading).group()).strip()}
                         for tag in dj_trial.heading.names
                         if tag not in skip_adding_columns]

        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (dj_trial & session_key).fetch(as_dict=True):
            start_time = float(trial['trial_start_time'])
            # presumably stored relative to trial start: shift to session time
            for event in ('trial_pole_in_time', 'trial_pole_out_time', 'trial_cue_time'):
                trial[event] = start_time + trial[event]
            for k in skip_adding_columns:
                trial.pop(k)
            # Build a new dict instead of renaming keys in place: popping and
            # inserting keys while iterating a dict raises RuntimeError.
            trial_row = {'start_time': start_time, 'stop_time': start_time + 5.0}
            trial_row.update({k.replace('trial_', ''): v for k, v in trial.items()})
            nwbfile.add_trial(**trial_row)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        save_path = os.path.join(nwb_output_dir, save_file_name)
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(save_path):
            return nwbfile
        with NWBHDF5IO(save_path, mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
예제 #2
0
    def run_conversion(self, nwbfile: NWBFile, metadata: dict):
        """
        Run conversion for this data interface.

        Reads the Labview behavioral data found in
        ``self.source_data['dir_behavior_labview']`` and adds it to ``nwbfile``:

        * trial summary files (names ending in ``_sum.txt``) become the trials
          table, unless ``nwbfile`` already has trials;
        * the matching continuous files (same name without ``_sum``) become the
          ``left_lick`` / ``right_lick`` acquisition time series and an
          ``OptogeneticSeries`` stimulus built from the ``Opto`` column.

        All times are shifted so that the first ``StartT`` of the first
        (alphabetically sorted) summary file is zero.

        Parameters
        ----------
        nwbfile : NWBFile
            File the data is added to (modified in place).
        metadata : dict
            Must provide ``metadata['Ogen']`` with ``'Device'`` (name,
            description), ``'OptogeneticStimulusSite'`` (name, description,
            excitation_lambda, location) and ``'OptogeneticSeries'`` (name,
            description) entries.

        Raises
        ------
        IndexError
            If the directory contains no ``*_sum.txt`` file.
        """
        summary_suffix = '_sum.txt'
        print("Converting Labview data...")
        dir_behavior_labview = self.source_data['dir_behavior_labview']
        # Trial summary files, sorted so the first one defines t0. endswith() rather
        # than a substring test, so names like 'x_sum.txt.bak' are not picked up.
        trials_files = sorted(f for f in os.listdir(dir_behavior_labview)
                              if f.endswith(summary_suffix))

        colnames = [
            'Trial', 'StartT', 'EndT', 'Result', 'InitT', 'SpecificResults',
            'ProbLeft', 'OptoDur', 'LRew', 'RRew', 'InterT', 'LTrial',
            'ReactionTime', 'OptoCond', 'OptoTrial'
        ]

        def read_summary(file_name):
            # summary files are tab-separated and have no header row
            return pd.read_csv(os.path.join(dir_behavior_labview, file_name),
                               sep='\t', index_col=False, names=colnames)

        # Get session_start_time from first file timestamps
        df_0 = read_summary(trials_files[0])
        t0 = df_0['StartT'][0]  # initial time in Labview seconds

        # Add trials
        print("Converting Labview trials data...")
        if nwbfile.trials is not None:
            print(
                'Trials already exist in current nwb file. Labview behavior trials not added.'
            )
        else:
            # The first summary file is already loaded; read only the remaining ones
            df_trials_summary = pd.concat(
                [df_0] + [read_summary(f) for f in trials_files[1:]])

            trial_columns = (
                ('results',
                 "0 means success (rewarded trial), 1 means licks during initial "
                 "period, which leads to a failed trial. 2 means early lick failure. 3 means "
                 "wrong lick or no response."),
                ('init_t', "duration of initial delay period."),
                ('specific_results',
                 "Possible outcomes classified based on raw data & meta file (_tr.m)."),
                ('prob_left',
                 "probability for left trials in order to keep the number of "
                 "left and right trials balanced within the session. "),
                ('opto_dur', "the duration of optical stimulation."),
                ('l_rew_n', "counting the number of left rewards."),
                ('r_rew_n', "counting the number of right rewards."),
                ('inter_t', "inter-trial delay period."),
                ('l_trial',
                 "trial type (which side the air-puff is applied). 1 means "
                 "left-trial, 0 means right-trial"),
                ('reaction_time',
                 "if it is a successful trial or wrong lick during response "
                 "period trial: ReactionTime = time between the first decision "
                 "lick and the beginning of the response period. If it is a failed "
                 "trial due to early licks: reaction time = the duration of "
                 "the air-puff period (in other words, when the animal licks "
                 "during the sample period)."),
                ('opto_cond',
                 "0: no opto. 1: opto is on during sample period. "
                 "2: opto is on half way through the sample period (0.5s) "
                 "and 0.5 during the response period. 3: opto is on during "
                 "the response period."),
                ('opto_trial', "1: opto trials. 0: Non-opto trials."),
            )
            for column_name, column_description in trial_columns:
                nwbfile.add_trial_column(name=column_name,
                                         description=column_description)

            for _, row in df_trials_summary.iterrows():
                nwbfile.add_trial(
                    start_time=row['StartT'] - t0,
                    stop_time=row['EndT'] - t0,
                    results=int(row['Result']),
                    init_t=row['InitT'],
                    specific_results=int(row['SpecificResults']),
                    prob_left=row['ProbLeft'],
                    opto_dur=row['OptoDur'],
                    l_rew_n=int(row['LRew']),
                    r_rew_n=int(row['RRew']),
                    inter_t=row['InterT'],
                    l_trial=int(row['LTrial']),
                    # NOTE(review): int() truncates; confirm ReactionTime is integral
                    # in the Labview files (the other durations are kept as floats).
                    reaction_time=int(row['ReactionTime']),
                    opto_cond=int(row['OptoCond']),
                    opto_trial=int(row['OptoTrial']),
                )

        # Continuous files share the summary file name without '_sum'
        # ('abc_sum.txt' -> 'abc.txt'); only the trailing suffix is replaced so a
        # '_sum' elsewhere in the name is kept.
        continuous_files = [f[:-len(summary_suffix)] + '.txt' for f in trials_files]
        df_continuous = pd.concat(
            [pd.read_csv(os.path.join(dir_behavior_labview, f), sep='\t', index_col=False)
             for f in continuous_files])
        # timestamps shared by all continuous signals, relative to t0
        timestamps = df_continuous['Time'].to_numpy() - t0

        # Behavioral data
        print("Converting Labview behavior data...")
        l1_ts = TimeSeries(name="left_lick",
                           data=df_continuous['Lick 1'].to_numpy(),
                           timestamps=timestamps,
                           description="no description")
        l2_ts = TimeSeries(name="right_lick",
                           data=df_continuous['Lick 2'].to_numpy(),
                           timestamps=timestamps,
                           description="no description")

        nwbfile.add_acquisition(l1_ts)
        nwbfile.add_acquisition(l2_ts)

        # Optogenetics stimulation data
        print("Converting Labview optogenetics data...")
        meta_ogen = metadata['Ogen']
        ogen_device = nwbfile.create_device(
            name=meta_ogen['Device']['name'],
            description=meta_ogen['Device']['description'])

        meta_ogen_site = meta_ogen['OptogeneticStimulusSite']
        ogen_stim_site = OptogeneticStimulusSite(
            name=meta_ogen_site['name'],
            device=ogen_device,
            description=meta_ogen_site['description'],
            excitation_lambda=float(meta_ogen_site['excitation_lambda']),
            location=meta_ogen_site['location'])
        nwbfile.add_ogen_site(ogen_stim_site)

        meta_ogen_series = meta_ogen['OptogeneticSeries']
        ogen_series = OptogeneticSeries(
            name=meta_ogen_series['name'],
            data=df_continuous['Opto'].to_numpy(),
            site=ogen_stim_site,
            timestamps=timestamps,
            description=meta_ogen_series['description'],
        )
        nwbfile.add_stimulus(ogen_series)
예제 #3
0
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=True):
    """
    Export one session to an NWB 2.0 file.

    Adds session and subject metadata and, when present for the session, the
    intracellular recording (membrane potential and spike train), behavior time
    series, photostimulation (site and optogenetic series) and the trial table
    with its events.

    Parameters
    ----------
    session_key : dict
        DataJoint restriction selecting exactly one ``acquisition.Session``.
    nwb_output_dir : str
        Directory the file is written to when ``save`` is True (created if missing).
    save : bool
        If True, write ``<nwb_output_dir>/<identifier>.nwb``.
    overwrite : bool
        If False and the target file already exists, nothing is written.

    Returns
    -------
    NWBFile
        The in-memory NWB file. It is returned (and saved, if requested) even
        when the session has no trial set.
    """
    this_session = (acquisition.Session & session_key).fetch1()
    # =============== General ====================
    # -- NWB file - a NWB2.0 file for each session
    nwbfile = NWBFile(session_description=this_session['session_note'],
                      # str(): '_'.join() requires strings (no-op if already str)
                      identifier='_'.join([
                          str(this_session['subject_id']),
                          this_session['session_time'].strftime('%Y-%m-%d'),
                          str(this_session['session_id'])
                      ]),
                      session_start_time=this_session['session_time'],
                      file_create_date=datetime.now(tzlocal()),
                      experimenter='; '.join(
                          (acquisition.Session.Experimenter
                           & session_key).fetch('experimenter')),
                      institution=institution,
                      experiment_description=experiment_description,
                      related_publications=related_publications,
                      keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=str(this_session['subject_id']),
        description=subj['subject_description'],
        genotype=' x '.join(
            (subject.Subject.Allele & session_key).fetch('allele')),
        sex=subj['sex'],
        species=subj['species'])

    # =============== Intracellular ====================
    cell = ((intracellular.Cell & session_key).fetch1()
            if len(intracellular.Cell & session_key) == 1 else None)
    if cell:
        # metadata
        whole_cell_device = nwbfile.create_device(name=cell['device_name'])
        ic_electrode = nwbfile.create_ic_electrode(
            name=cell['cell_id'],
            device=whole_cell_device,
            description='N/A',
            filtering='N/A',
            location='; '.join([
                f'{k}: {str(v)}'
                for k, v in dict((reference.BrainLocation & cell).fetch1(),
                                 depth=cell['recording_depth']).items()
            ]))
        # acquisition - membrane potential
        mp, mp_timestamps = (intracellular.MembranePotential & cell).fetch1(
            'membrane_potential', 'membrane_potential_timestamps')
        nwbfile.add_acquisition(
            pynwb.icephys.PatchClampSeries(name='PatchClampSeries',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp,
                                           timestamps=mp_timestamps))

        # acquisition - spike train
        spk, spk_timestamps = (intracellular.SpikeTrain & cell).fetch1(
            'spike_train', 'spike_timestamps')
        nwbfile.add_acquisition(
            pynwb.icephys.PatchClampSeries(name='SpikeTrain',
                                           electrode=ic_electrode,
                                           unit='a.u.',
                                           conversion=1e1,
                                           gain=1.0,
                                           data=spk,
                                           timestamps=spk_timestamps))

    # =============== Behavior ====================
    behavior_data = ((behavior.Behavior & session_key).fetch1()
                     if len(behavior.Behavior & session_key) == 1 else None)
    if behavior_data:
        behav_acq = pynwb.behavior.BehavioralTimeSeries(name='behavior')
        nwbfile.add_acquisition(behav_acq)
        for k in behavior.Behavior.primary_key:
            behavior_data.pop(k)
        timestamps = behavior_data.pop('behavior_timestamps')

        # Each series is described by its attribute's comment in the table
        # definition, looked up by exact attribute name (a text search of the
        # heading could match inside a longer attribute name).
        behavior_attributes = behavior.Behavior.heading.attributes
        for b_k, b_v in behavior_data.items():
            behav_acq.create_timeseries(name=b_k,
                                        description=behavior_attributes[b_k].comment.strip(),
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=b_v,
                                        timestamps=timestamps)

    # =============== Photostimulation ====================
    photostim = ((stimulation.PhotoStimulation & session_key).fetch1()
                 if len(stimulation.PhotoStimulation & session_key) == 1 else None)
    if photostim:
        photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
        stim_device = nwbfile.create_device(
            name=photostim_device['device_name'])
        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name='-'.join([photostim['hemisphere'],
                           photostim['brain_region']]),
            device=stim_device,
            excitation_lambda=float(
                (stimulation.PhotoStimulationProtocol
                 & photostim).fetch1('photo_stim_excitation_lambda')),
            location='; '.join([
                f'{k}: {str(v)}' for k, v in (reference.ActionLocation
                                              & photostim).fetch1().items()
            ]),
            description=(stimulation.PhotoStimulationProtocol
                         & photostim).fetch1('photo_stim_notes'))
        nwbfile.add_ogen_site(stim_site)

        if photostim['photostim_timeseries'] is not None:
            nwbfile.add_stimulus(
                pynwb.ogen.OptogeneticSeries(
                    name='_'.join([
                        'photostim_on',
                        photostim['photostim_datetime'].strftime(
                            '%Y-%m-%d_%H-%M-%S')
                    ]),
                    site=stim_site,
                    resolution=0.0,
                    conversion=1e-3,
                    data=photostim['photostim_timeseries'],
                    starting_time=photostim['photostim_start_time'],
                    rate=photostim['photostim_sampling_rate']))

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    # 'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added to the trial-table as additional
    # columns (with column name and column description).
    if len(acquisition.TrialSet & session_key) == 1:
        trial_pk = acquisition.TrialSet.Trial.primary_key
        trial_attributes = acquisition.TrialSet.Trial.heading.attributes
        # Trial descriptors from TrialSet.Trial - remove 'trial_' prefix (if any);
        # the description is the attribute's comment in the table definition.
        trial_columns = [{'name': tag.replace('trial_', ''),
                          'description': trial_attributes[tag].comment.strip()}
                         for tag in acquisition.TrialSet.Trial.heading.names
                         if tag not in trial_pk + ['start_time', 'stop_time']]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have
        # start_time and stop_time; all events get a `_time` suffix
        trial_events = set(((acquisition.TrialSet.EventTime & session_key)
                            - [{'trial_event': 'trial_start'},
                               {'trial_event': 'trial_stop'}]).fetch('trial_event'))
        event_restriction = [{'trial_event': e} for e in trial_events]
        event_names = [{'name': e + '_time', 'description': d}
                       for e, d in zip(*(reference.ExperimentalEvent
                                         & [{'event': k} for k in trial_events]
                                         ).fetch('event', 'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial & session_key).fetch(as_dict=True):
            events = dict(zip(*(acquisition.TrialSet.EventTime & trial & event_restriction
                                ).fetch('trial_event', 'event_time')))
            # shift event times to be relative to session_start (currently relative to trial_start)
            events = {k: v + trial['start_time'] for k, v in events.items()}

            trial_tag_value = {**trial, **events}
            # rename 'trial_id' to 'id'
            trial_tag_value['id'] = trial_tag_value['trial_id']
            for k in trial_pk:
                trial_tag_value.pop(k)

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {
                k.replace('trial_', ''): trial_tag_value.pop(k)
                for k in [n for n in trial_tag_value if n.startswith('trial_')]
            }

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    # Done for every session, not only those with a trial set.
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        save_path = os.path.join(nwb_output_dir, save_file_name)
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(save_path):
            return nwbfile
        with NWBHDF5IO(save_path, mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
예제 #4
0
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=False):
    """Build an NWB 2.0 file for one experiment session and optionally save it.

    The file is assembled from the DataJoint schemas in this order:
    session/subject/virus metadata, extracellular electrodes and sorted units
    (per probe insertion), lick tracking traces, photostimulation sites and
    traces, the behavior trial table, and trial/photostim events.

    :param session_key: restriction selecting exactly one ``experiment.Session``
        (``fetch1`` is used on it)
    :param nwb_output_dir: directory the ``.nwb`` file is written to when ``save``
    :param save: if True, write the file as ``<identifier>.nwb`` in
        ``nwb_output_dir`` (the directory is created if needed)
    :param overwrite: if False and the target file already exists, nothing is
        written and the in-memory NWBFile is returned
    :return: the assembled ``pynwb.NWBFile``
    """
    this_session = (experiment.Session & session_key).fetch1()
    print(f'Exporting to NWB 2.0 for session: {this_session}...')
    # ===============================================================================
    # ============================== META INFORMATION ===============================
    # ===============================================================================

    # experiment description, publications and keywords are looked up per project
    sess_desc = session_description_mapper[(
        experiment.ProjectSession & session_key).fetch1('project_name')]

    # -- NWB file - a NWB2.0 file for each session
    nwbfile = NWBFile(
        identifier='_'.join([
            'ANM' + str(this_session['subject_id']),
            this_session['session_date'].strftime('%Y-%m-%d'),
            str(this_session['session'])
        ]),
        session_description='',
        session_start_time=datetime.combine(this_session['session_date'],
                                            zero_zero_time),
        file_create_date=datetime.now(tzlocal()),
        experimenter=this_session['username'],
        institution=institution,
        experiment_description=sess_desc['experiment_description'],
        related_publications=sess_desc['related_publications'],
        keywords=sess_desc['keywords'])

    # -- subject
    subj = (lab.Subject & session_key).aggr(
        lab.Subject.Strain, ...,
        strains='GROUP_CONCAT(animal_strain)').fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=str(this_session['subject_id']),
        description=
        f'source: {subj["animal_source"]}; strains: {subj["strains"]}',
        genotype=' x '.join((lab.Subject.GeneModification
                             & subj).fetch('gene_modification')),
        sex=subj['sex'],
        species=subj['species'],
        date_of_birth=datetime.combine(subj['date_of_birth'], zero_zero_time)
        if subj['date_of_birth'] else None)
    # -- virus: serialized to JSON; attributes already in the subject record are omitted
    nwbfile.virus = json.dumps([{
        k: str(v)
        for k, v in virus_injection.items() if k not in subj
    } for virus_injection in virus.VirusInjection * virus.Virus & session_key])

    # ===============================================================================
    # ======================== EXTRACELLULAR & CLUSTERING ===========================
    # ===============================================================================
    # In the event of multiple probe recording (i.e. multiple probe insertions), the
    # clustering results (and the associated units) are associated with the
    # corresponding probe. Each probe insertion is associated with one
    # ElectrodeConfiguration (which may define multiple electrode groups).

    dj_insert_location = ephys.ProbeInsertion.InsertionLocation.aggr(
        ephys.ProbeInsertion.RecordableBrainRegion.proj(
            brain_region='CONCAT(hemisphere, " ", brain_area)'), ...,
        brain_regions='GROUP_CONCAT(brain_region)')

    # The units table is shared by all probe insertions, so its custom columns
    # must be declared only once (re-adding an existing column is an error).
    unit_columns_added = False

    for probe_insertion in ephys.ProbeInsertion & session_key:
        electrode_config = (lab.ElectrodeConfig & probe_insertion).fetch1()

        # Restrict by this probe insertion (not the whole session) so fetch1()
        # still works for sessions with several insertions. The location is the
        # same for every electrode group of this insertion, so compute it once.
        insert_location = json.dumps({
            k: str(v)
            for k, v in (dj_insert_location & probe_insertion).fetch1().items()
            if k not in dj_insert_location.primary_key
        })

        # NWB container names must be unique: reuse the probe device if it is
        # already registered instead of creating it once per electrode group.
        probe_name = electrode_config['probe']
        probe_device = (nwbfile.get_device(probe_name)
                        if probe_name in nwbfile.devices else
                        nwbfile.create_device(name=probe_name))

        electrode_groups = {}
        for electrode_group in lab.ElectrodeConfig.ElectrodeGroup & electrode_config:
            electrode_groups[electrode_group[
                'electrode_group']] = nwbfile.create_electrode_group(
                    name=electrode_config['electrode_config_name'] + '_g' +
                    str(electrode_group['electrode_group']),
                    description='N/A',
                    device=probe_device,
                    location=insert_location)

        for chn in (lab.ElectrodeConfig.Electrode * lab.Probe.Electrode
                    & electrode_config).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['electrode'],
                group=electrode_groups[chn['electrode_group']],
                filtering=hardware_filter,
                imp=-1.,
                # 0 is a valid coordinate; only a missing (None) value maps to NaN
                x=chn['x_coord'] if chn['x_coord'] is not None else np.nan,
                y=chn['y_coord'] if chn['y_coord'] is not None else np.nan,
                z=chn['z_coord'] if chn['z_coord'] is not None else np.nan,
                location=electrode_groups[chn['electrode_group']].location)

        # --- unit spike times ---
        if not unit_columns_added:
            nwbfile.add_unit_column(
                name='sampling_rate',
                description='Sampling rate of the raw voltage traces (Hz)')
            nwbfile.add_unit_column(name='quality',
                                    description='unit quality from clustering')
            nwbfile.add_unit_column(
                name='posx',
                description=
                'estimated x position of the unit relative to probe (0,0) (um)')
            nwbfile.add_unit_column(
                name='posy',
                description=
                'estimated y position of the unit relative to probe (0,0) (um)')
            nwbfile.add_unit_column(
                name='cell_type',
                description='cell type (e.g. fast spiking or pyramidal)')
            unit_columns_added = True

        for unit_key in (ephys.Unit * ephys.UnitCellType
                         & probe_insertion).fetch('KEY'):

            unit = (ephys.Unit * ephys.UnitCellType & probe_insertion
                    & unit_key).proj(..., '-spike_times').fetch1()
            if ephys.TrialSpikes & unit_key:
                obs_intervals = np.array(
                    list(
                        zip(*(ephys.TrialSpikes * experiment.SessionTrial
                              & unit_key).fetch('start_time',
                                                'stop_time')))).astype(float)
                tr_spike_times, tr_start, tr_go = (
                    ephys.TrialSpikes * experiment.SessionTrial *
                    (experiment.TrialEvent & 'trial_event_type = "go"')
                    & unit_key).fetch('spike_times', 'start_time',
                                      'trial_event_time')
                # trialized spike times are shifted by the trial's "go" event
                # time and the trial start time, then concatenated
                spike_times = np.hstack([
                    spks + float(t_go) + float(t_start) for spks, t_start, t_go
                    in zip(tr_spike_times, tr_start, tr_go)
                ])
            else:  # the case of unavailable `TrialSpikes`
                spike_times = (ephys.Unit & unit_key).fetch1('spike_times')
                obs_intervals = np.array(
                    list(
                        zip(*(experiment.SessionTrial & unit_key).fetch(
                            'start_time', 'stop_time')))).astype(float)
                # keep only the trial intervals that contain at least one spike
                obs_intervals = [
                    interval for interval in obs_intervals
                    if np.logical_and(spike_times >= interval[0],
                                      spike_times <= interval[-1]).any()
                ]

            # make an electrode table region (which electrode(s) is this unit coming from):
            # the region holds row indices, so map the electrode id to its row
            nwbfile.add_unit(
                id=unit['unit'],
                electrodes=np.where(
                    np.array(nwbfile.electrodes.id.data) ==
                    unit['electrode'])[0],
                electrode_group=electrode_groups[unit['electrode_group']],
                obs_intervals=obs_intervals,
                sampling_rate=ecephys_fs,
                quality=unit['unit_quality'],
                posx=unit['unit_posx'],
                posy=unit['unit_posy'],
                cell_type=unit['cell_type'],
                spike_times=spike_times,
                waveform_mean=np.mean(unit['waveform'], axis=0),
                waveform_sd=np.std(unit['waveform'], axis=0))

    # ===============================================================================
    # ============================= BEHAVIOR TRACKING ===============================
    # ===============================================================================

    if tracking.LickTrace * experiment.SessionTrial & session_key:
        # re-concatenating trialized tracking traces (timestamps shifted by trial start)
        lick_traces, time_vecs, trial_starts = (
            tracking.LickTrace * experiment.SessionTrial & session_key).fetch(
                'lick_trace', 'lick_trace_timestamps', 'start_time')
        behav_acq = pynwb.behavior.BehavioralTimeSeries(
            name='BehavioralTimeSeries')
        nwbfile.add_acquisition(behav_acq)
        behav_acq.create_timeseries(
            name='lick_trace',
            unit='a.u.',
            conversion=1.0,
            data=np.hstack(lick_traces),
            description=
            "Time-series of the animal's tongue movement when licking",
            timestamps=np.hstack(time_vecs + trial_starts.astype(float)))

    # ===============================================================================
    # ============================= PHOTO-STIMULATION ===============================
    # ===============================================================================
    stim_sites = {}  # photo_stim -> OptogeneticStimulusSite
    for photostim in experiment.Photostim * experiment.PhotostimBrainRegion * lab.PhotostimDevice & session_key:

        stim_device = (nwbfile.get_device(photostim['photostim_device'])
                       if photostim['photostim_device'] in nwbfile.devices else
                       nwbfile.create_device(
                           name=photostim['photostim_device']))

        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name=photostim['stim_laterality'] + ' ' +
            photostim['stim_brain_area'],
            device=stim_device,
            excitation_lambda=float(photostim['excitation_wavelength']),
            location=json.dumps([{
                k: v
                for k, v in stim_locs.items()
                if k not in experiment.Photostim.primary_key
            } for stim_locs in (experiment.Photostim.PhotostimLocation.proj(
                ..., '-brain_area')
                                & photostim).fetch(as_dict=True)],
                                default=str),
            description='')
        nwbfile.add_ogen_site(stim_site)
        stim_sites[photostim['photo_stim']] = stim_site

    # re-concatenating trialized photostim traces
    dj_photostim = (experiment.PhotostimTrace * experiment.SessionTrial *
                    experiment.PhotostimEvent * experiment.Photostim
                    & session_key)

    for photo_stim, stim_site in stim_sites.items():
        if dj_photostim & {'photo_stim': photo_stim}:
            aom_input_trace, laser_power, time_vecs, trial_starts = (
                dj_photostim & {
                    'photo_stim': photo_stim
                }).fetch('aom_input_trace', 'laser_power',
                         'photostim_timestamps', 'start_time')

            aom_series = pynwb.ogen.OptogeneticSeries(
                name=stim_site.name + '_aom_input_trace',
                site=stim_site,
                conversion=1e-3,
                data=np.hstack(aom_input_trace),
                timestamps=np.hstack(time_vecs + trial_starts.astype(float)))
            laser_series = pynwb.ogen.OptogeneticSeries(
                name=stim_site.name + '_laser_power',
                site=stim_site,
                conversion=1e-3,
                data=np.hstack(laser_power),
                timestamps=np.hstack(time_vecs + trial_starts.astype(float)))

            nwbfile.add_stimulus(aom_series)
            nwbfile.add_stimulus(laser_series)

    # ===============================================================================
    # =============================== BEHAVIOR TRIALS ===============================
    # ===============================================================================

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory
    # attributes: 'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added in to the trial-table as
    # additional columns (with column name and column description)

    dj_trial = experiment.SessionTrial * experiment.BehaviorTrial
    skip_adding_columns = experiment.Session.primary_key + [
        'trial_uid', 'trial'
    ]

    if experiment.SessionTrial & session_key:
        # Column descriptions are parsed from the text of the DataJoint heading:
        # everything after the attribute name, with whitespace and ':' collapsed.
        trial_heading = str(dj_trial.heading)
        trial_columns = [{
            'name':
            tag,
            'description':
            re.sub(r'\s+:|\s+', ' ',
                   re.search(f'(?<={tag})(.*)', trial_heading).group()).strip()
        } for tag in dj_trial.heading.names if tag not in skip_adding_columns +
                         ['start_time', 'stop_time']]

        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (dj_trial & session_key).fetch(as_dict=True):
            trial['start_time'] = float(trial['start_time'])
            # only a missing stop time becomes NaN (0 is a valid time)
            trial['stop_time'] = (float(trial['stop_time'])
                                  if trial['stop_time'] is not None else np.nan)
            trial['id'] = trial['trial']  # rename 'trial' to 'id'
            for k in skip_adding_columns:
                trial.pop(k)
            nwbfile.add_trial(**trial)

    # ===============================================================================
    # =============================== BEHAVIOR TRIAL EVENTS ==========================
    # ===============================================================================

    behav_event = pynwb.behavior.BehavioralEvents(name='BehavioralEvents')
    nwbfile.add_acquisition(behav_event)

    for trial_event_type in (experiment.TrialEventType & experiment.TrialEvent
                             & session_key).fetch('trial_event_type'):
        event_times, trial_starts = (
            experiment.TrialEvent * experiment.SessionTrial
            & session_key & {
                'trial_event_type': trial_event_type
            }).fetch('trial_event_time', 'start_time')
        if len(event_times) > 0:
            # event times are stored relative to trial start; shift to session time
            event_times = np.hstack(
                event_times.astype(float) + trial_starts.astype(float))
            behav_event.create_timeseries(name=trial_event_type,
                                          unit='a.u.',
                                          conversion=1.0,
                                          data=np.full_like(event_times, 1),
                                          timestamps=event_times)

    photostim_event_time, trial_starts, photo_stim, power, duration = (
        experiment.PhotostimEvent * experiment.SessionTrial
        & session_key).fetch('photostim_event_time', 'start_time',
                             'photo_stim', 'power', 'duration')

    if len(photostim_event_time) > 0:
        photostim_start = (photostim_event_time.astype(float) +
                           trial_starts.astype(float))
        # NOTE(review): stim_sites maps photo_stim -> OptogeneticStimulusSite;
        # pynwb documents control_description as a description of each control
        # value - confirm this dict is serialized as intended.
        behav_event.create_timeseries(
            name='photostim_start_time',
            unit='a.u.',
            conversion=1.0,
            data=power,
            timestamps=photostim_start,
            control=photo_stim.astype('uint8'),
            control_description=stim_sites)
        behav_event.create_timeseries(
            name='photostim_stop_time',
            unit='a.u.',
            conversion=1.0,
            # float zeros: the raw event times are only cast with astype(float),
            # so their fetched dtype may be non-float (e.g. Decimal objects)
            data=np.zeros_like(photostim_start),
            timestamps=photostim_event_time.astype(float) +
            duration.astype(float) + trial_starts.astype(float),
            control=photo_stim.astype('uint8'),
            control_description=stim_sites)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        save_path = os.path.join(nwb_output_dir, save_file_name)
        os.makedirs(nwb_output_dir, exist_ok=True)
        if not overwrite and os.path.exists(save_path):
            return nwbfile
        with NWBHDF5IO(save_path, mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
예제 #5
0
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=True):
    """Build an NWB 2.0 file for one ``acquisition.Session`` and optionally save it.

    The file contains, when present for the session: subject metadata, a
    whole-cell recording (membrane potential, current injection and a
    spike-removed membrane potential), one extracellular probe insertion with
    its units, lick times reconstructed from trial-segmented data,
    photostimulation, and the trial table (trial attributes, trial event times
    and per-trial photostim parameters).

    :param session_key: restriction selecting exactly one ``acquisition.Session``
        (``fetch1`` is used on it)
    :param nwb_output_dir: directory the ``.nwb`` file is written to when ``save``
    :param save: if True, write the file as ``<identifier>.nwb`` in
        ``nwb_output_dir`` (the directory is created if needed)
    :param overwrite: if False and the target file already exists, nothing is
        written and the in-memory NWBFile is returned
    :return: the assembled ``pynwb.NWBFile``
    """
    this_session = (acquisition.Session & session_key).fetch1()

    # str() so that non-string key values (e.g. an integer id) can be joined
    identifier = '_'.join([
        str(this_session['subject_id']),
        this_session['session_time'].strftime('%Y-%m-%d'),
        str(this_session['session_id'])
    ])

    # =============== General ====================
    # -- NWB file - a NWB2.0 file for each session
    nwbfile = NWBFile(session_description=this_session['session_note'],
                      identifier=identifier,
                      session_start_time=this_session['session_time'],
                      file_create_date=datetime.now(tzlocal()),
                      experimenter='; '.join(
                          (acquisition.Session.Experimenter
                           & session_key).fetch('experimenter')),
                      institution=institution,
                      experiment_description=experiment_description,
                      related_publications=related_publications,
                      keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=this_session['subject_id'],
        description=subj['subject_description'],
        genotype=' x '.join(
            (subject.Subject.Allele & session_key).fetch('allele')),
        sex=subj['sex'],
        species=subj['species'])

    # =============== Intracellular ====================
    # only exported when the session has exactly one recorded cell
    cell = ((intracellular.Cell
             & session_key).fetch1() if len(intracellular.Cell
                                            & session_key) == 1 else None)
    if cell:
        # metadata
        whole_cell_device = nwbfile.create_device(name=cell['device_name'])
        ic_electrode = nwbfile.create_ic_electrode(
            name=cell['cell_id'],
            device=whole_cell_device,
            description='N/A',
            filtering='N/A',
            location='; '.join([
                f'{k}: {str(v)}'
                for k, v in dict((reference.BrainLocation & cell).fetch1(),
                                 depth=cell['cell_depth']).items()
            ]))
        # acquisition - membrane potential
        mp, mp_wo_spike, mp_start_time, mp_fs = (
            intracellular.MembranePotential & cell).fetch1(
                'membrane_potential', 'membrane_potential_wo_spike',
                'membrane_potential_start_time',
                'membrane_potential_sampling_rate')
        nwbfile.add_acquisition(
            pynwb.icephys.PatchClampSeries(name='PatchClampSeries',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))
        # acquisition - current injection
        if intracellular.CurrentInjection & cell:
            current_injection, ci_start_time, ci_fs = (
                intracellular.CurrentInjection & cell).fetch1(
                    'current_injection', 'current_injection_start_time',
                    'current_injection_sampling_rate')
            nwbfile.add_stimulus(
                pynwb.icephys.CurrentClampStimulusSeries(
                    name='CurrentClampStimulus',
                    electrode=ic_electrode,
                    conversion=1e-9,
                    gain=1.0,
                    data=current_injection,
                    starting_time=ci_start_time,
                    rate=ci_fs))

        # analysis - membrane potential without spike
        mp_rmv_spike = nwbfile.create_processing_module(
            name='icephys', description='Spike removal')
        mp_rmv_spike.add_data_interface(
            pynwb.icephys.PatchClampSeries(name='icephys',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp_wo_spike,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))

    # =============== Extracellular ====================
    probe_insertion = ((extracellular.ProbeInsertion
                        & session_key).fetch1() if extracellular.ProbeInsertion
                       & session_key else None)
    if probe_insertion:
        probe = nwbfile.create_device(name=probe_insertion['probe_name'])
        electrode_group = nwbfile.create_electrode_group(
            name=f'{probe_insertion["probe_name"]}: {str(probe_insertion["channel_counts"])}',
            description='N/A',
            device=probe,
            location='; '.join([
                f'{k}: {str(v)}' for k, v in
                (reference.BrainLocation & probe_insertion).fetch1().items()
            ]))

        for chn in (reference.Probe.Channel
                    & probe_insertion).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['channel_id'],
                group=electrode_group,
                filtering=hardware_filter,
                imp=-1.,
                x=0.0,  # not available from data
                y=0.0,  # not available from data
                z=0.0,  # not available from data
                location=electrode_group.location)

        # --- unit spike times ---
        nwbfile.add_unit_column(
            name='sampling_rate',
            description='Sampling rate of the raw voltage traces (Hz)')
        nwbfile.add_unit_column(name='depth',
                                description='depth this unit (mm)')
        nwbfile.add_unit_column(name='spike_width',
                                description='spike width of this unit (ms)')
        nwbfile.add_unit_column(
            name='cell_type',
            description='cell type (e.g. wide width, narrow width spiking)')

        # The units' electrode region holds *row indices* into the electrodes
        # table, whereas channel_id is the electrode *id* assigned above.
        electrode_ids = (np.array(nwbfile.electrodes.id.data)
                         if nwbfile.electrodes is not None else np.array([]))

        for unit in (extracellular.UnitSpikeTimes
                     & probe_insertion).fetch(as_dict=True):
            # make an electrode table region (which electrode(s) is this unit coming from)
            unit_rows = np.where(
                np.isin(electrode_ids, np.atleast_1d(unit['channel_id'])))[0]
            nwbfile.add_unit(
                id=unit['unit_id'],
                electrodes=unit_rows,
                depth=unit['unit_depth'],
                sampling_rate=ecephys_fs,
                spike_width=unit['unit_spike_width'],
                cell_type=unit['unit_cell_type'],
                spike_times=unit['spike_times'],
                waveform_mean=unit['spike_waveform'])

    # =============== Behavior ====================
    # Note: for this study, raw behavioral data were not available, only trialized data were provided
    # here, we reconstruct raw behavioral data by concatenation
    trial_seg_setting = (analysis.TrialSegmentationSetting
                         & 'trial_seg_setting=0').fetch1()
    seg_behav_query = (
        behavior.TrialSegmentedLickTrace * acquisition.TrialSet.Trial *
        (analysis.RealignedEvent.RealignedEventTime
         & 'trial_event="trial_start"')
        & session_key & trial_seg_setting)

    if seg_behav_query:
        behav_acq = pynwb.behavior.BehavioralTimeSeries(name='lick_times')
        nwbfile.add_acquisition(behav_acq)
        seg_behav = pd.DataFrame(
            seg_behav_query.fetch('start_time', 'realigned_event_time',
                                  'segmented_lick_left_on',
                                  'segmented_lick_left_off',
                                  'segmented_lick_right_on',
                                  'segmented_lick_right_off')).T
        seg_behav.columns = [
            'start_time', 'realigned_event_time', 'segmented_lick_left_on',
            'segmented_lick_left_off', 'segmented_lick_right_on',
            'segmented_lick_right_off'
        ]
        for behav_name in [
                'lick_left_on', 'lick_left_off', 'lick_right_on',
                'lick_right_off'
        ]:
            # undo the realignment and shift to session time, per trial;
            # np.hstack requires a sequence (list), not a generator
            lick_times = np.hstack([
                r['segmented_' + behav_name] - r.realigned_event_time +
                r.start_time for _, r in seg_behav.iterrows()
            ])
            behav_acq.create_timeseries(name=behav_name,
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=np.full_like(lick_times, 1),
                                        timestamps=lick_times)

    # =============== Photostimulation ====================
    photostim = ((stimulation.PhotoStimulation
                  & session_key).fetch1() if stimulation.PhotoStimulation
                 & session_key else None)
    if photostim:
        photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
        stim_device = nwbfile.create_device(
            name=photostim_device['device_name'])
        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name='-'.join([photostim['hemisphere'],
                           photostim['brain_region']]),
            device=stim_device,
            excitation_lambda=float(
                (stimulation.PhotoStimProtocol
                 & photostim).fetch1('photo_stim_excitation_lambda')),
            location='; '.join([
                f'{k}: {str(v)}' for k, v in (reference.ActionLocation
                                              & photostim).fetch1().items()
            ]),
            description=(stimulation.PhotoStimProtocol
                         & photostim).fetch1('photo_stim_notes'))
        nwbfile.add_ogen_site(stim_site)

        if photostim['photostim_timeseries'] is not None:
            nwbfile.add_stimulus(
                pynwb.ogen.OptogeneticSeries(
                    name='_'.join([
                        'photostim_on',
                        photostim['photostim_datetime'].strftime(
                            '%Y-%m-%d_%H-%M-%S')
                    ]),
                    site=stim_site,
                    resolution=0.0,
                    conversion=1e-3,
                    data=photostim['photostim_timeseries'],
                    starting_time=photostim['photostim_start_time'],
                    rate=photostim['photostim_sampling_rate']))

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    #                                                                       'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added in to the trial-table as additional columns (with column name
    # and column description)
    if acquisition.TrialSet & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo - remove '_trial' prefix (if any).
        # Descriptions are the comment text after '#' in the DataJoint heading.
        trial_heading = str(
            (acquisition.TrialSet.Trial * stimulation.TrialPhotoStimParam).heading)
        trial_columns = [
            {
                'name':
                tag.replace('trial_', ''),
                'description':
                re.search(f'(?<={tag})(.*)#(.*)',
                          trial_heading).groups()[-1].strip()
            } for tag in (acquisition.TrialSet.Trial *
                          stimulation.TrialPhotoStimParam).heading.names
            if tag not in (acquisition.TrialSet.Trial
                           & stimulation.TrialPhotoStimParam).primary_key +
            ['start_time', 'stop_time']
        ]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have start_time and stop_time
        # also add `_time` suffix to all events
        trial_events = set(((acquisition.TrialSet.EventTime & session_key) -
                            [{
                                'trial_event': 'trial_start'
                            }, {
                                'trial_event': 'trial_stop'
                            }]).fetch('trial_event'))
        event_names = [{
            'name': e + '_time',
            'description': d + ' - (s) relative to trial start time'
        } for e, d in zip(*(reference.ExperimentalEvent & [{
            'event': k
        } for k in trial_events]).fetch('event', 'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        # placeholder values for trials without photostim parameters
        photostim_tag_default = {
            tag: ''
            for tag in stimulation.TrialPhotoStimParam.heading.names
            if tag not in stimulation.TrialPhotoStimParam.primary_key
        }

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial
                      & session_key).fetch(as_dict=True):
            events = dict(
                zip(*(acquisition.TrialSet.EventTime & trial
                      & [{
                          'trial_event': e
                      } for e in trial_events]
                      ).fetch('trial_event', 'event_time')))

            trial_tag_value = ({
                **trial,
                **events,
                **(stimulation.TrialPhotoStimParam & trial).fetch1()
            } if (stimulation.TrialPhotoStimParam & trial) else {
                **trial,
                **events,
                **photostim_tag_default
            })

            trial_tag_value['id'] = trial_tag_value[
                'trial_id']  # rename 'trial_id' to 'id'
            for k in acquisition.TrialSet.Trial.primary_key:
                trial_tag_value.pop(k)

            # convert None to np.nan since nwb fields does not take None
            trial_tag_value = {
                k: v if v is not None else np.nan
                for k, v in trial_tag_value.items()
            }

            trial_tag_value['delay_duration'] = float(
                trial_tag_value['delay_duration'])  # convert Decimal to float

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {
                k.replace('trial_', ''): trial_tag_value.pop(k)
                for k in
                [n for n in trial_tag_value if n.startswith('trial_')]
            }

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        save_path = os.path.join(nwb_output_dir, save_file_name)
        os.makedirs(nwb_output_dir, exist_ok=True)
        if not overwrite and os.path.exists(save_path):
            return nwbfile
        with NWBHDF5IO(save_path, mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
예제 #6
0
    if photostim:
        photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
        stim_device = nwbfile.create_device(
            name=photostim_device['device_name'])
        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name='-'.join([photostim['hemisphere'],
                           photostim['brain_region']]),
            device=stim_device,
            excitation_lambda=float(photostim['photo_stim_excitation_lambda']),
            location='; '.join([
                f'{k}: {str(v)}' for k, v in (reference.ActionLocation
                                              & photostim).fetch1().items()
            ]),
            description=(stimulation.PhotoStimulationInfo
                         & photostim).fetch1('photo_stim_notes'))
        nwbfile.add_ogen_site(stim_site)

        if photostim['photostim_timeseries'] is not None:
            nwbfile.add_stimulus(
                pynwb.ogen.OptogeneticSeries(
                    name='_'.join([
                        'photostim_on',
                        photostim['photostim_datetime'].strftime(
                            '%Y-%m-%d_%H-%M-%S')
                    ]),
                    site=stim_site,
                    unit='mW',
                    resolution=0.0,
                    conversion=1e-6,
                    data=photostim['photostim_timeseries'],
                    starting_time=photostim['photostim_start_time'],
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=True):
    """Assemble an NWB 2.0 file for one session and optionally save it.

    The ``NWBFile`` is populated from DataJoint tables restricted by
    ``session_key``. Session and subject metadata are added first. Each of
    the following sections is added only when its source table has a row
    for this session (each is read with ``fetch1``):

    - intracellular recording (``intracellular.Cell``)
    - extracellular probe, electrodes and units
      (``extracellular.ProbeInsertion``)
    - lick trace (``behavior.LickTrace``)
    - photostimulation (``stimulation.PhotoStimulation``)

    Finally the trial table is built from ``acquisition.TrialSet``.

    Args:
        session_key: DataJoint restriction that should match exactly one
            ``acquisition.Session`` row (it is read with ``fetch1``).
        nwb_output_dir: Directory written to when ``save`` is true; it is
            created if it does not exist.
        save: If true, write the file as ``<nwbfile.identifier>.nwb`` in
            ``nwb_output_dir``.
        overwrite: If false and that file already exists, nothing is
            written.

    Returns:
        The populated ``NWBFile``, whether or not it was written to disk.
    """
    this_session = (acquisition.Session & session_key).fetch1()
    # =============== General ====================
    # -- NWB file - a NWB2.0 file for each session
    # identifier is "<subject_id>_<YYYY-MM-DD_HH-MM-SS>" of the session time
    nwbfile = NWBFile(
        session_description=this_session['session_note'],
        identifier='_'.join([
            this_session['subject_id'],
            this_session['session_time'].strftime('%Y-%m-%d_%H-%M-%S')
        ]),
        session_start_time=this_session['session_time'],
        file_create_date=datetime.now(tzlocal()),
        experimenter='; '.join((acquisition.Session.Experimenter
                                & session_key).fetch('experimenter')),
        institution=institution,
        experiment_description=experiment_description,
        related_publications=related_publications,
        keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    # genotype: all alleles of the subject joined with ' x '
    nwbfile.subject = pynwb.file.Subject(
        subject_id=this_session['subject_id'],
        description=subj['subject_description'],
        genotype=' x '.join((subject.Subject.Allele
                             & session_key).fetch('allele')),
        sex=subj['sex'],
        species=subj['species'])
    # =============== Intracellular ====================
    # The single cell row for this session (fetch1), or None when the
    # restricted table is empty.
    cell = ((intracellular.Cell & session_key).fetch1() if intracellular.Cell
            & session_key else None)
    if cell:
        # metadata
        # NOTE(review): this re-fetches the row already held in `cell`; the
        # second query looks redundant.
        cell = (intracellular.Cell & session_key).fetch1()
        whole_cell_device = nwbfile.create_device(name=cell['device_name'])
        # The electrode is named after cell_id. Its location is rendered as
        # "key: value; ..." from the matching reference.ActionLocation row.
        ic_electrode = nwbfile.create_ic_electrode(
            name=cell['cell_id'],
            device=whole_cell_device,
            description='N/A',
            filtering='low-pass: 10kHz',
            location='; '.join([
                f'{k}: {str(v)}'
                for k, v in (reference.ActionLocation & cell).fetch1().items()
            ]))
        # acquisition - membrane potential
        # The raw trace and the spike-removed trace share a start time and
        # sampling rate.
        mp, mp_wo_spike, mp_start_time, mp_fs = (
            intracellular.MembranePotential & cell).fetch1(
                'membrane_potential', 'membrane_potential_wo_spike',
                'membrane_potential_start_time',
                'membrane_potential_sampling_rate')
        # NOTE(review): NWB multiplies `data` by `conversion` to obtain values
        # in `unit`. unit='mV' with conversion=1e-3 implies stored data in uV.
        # Confirm the intended units (the same pair is used for the
        # spike-removed series below).
        nwbfile.add_acquisition(
            pynwb.icephys.PatchClampSeries(name='PatchClampSeries',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))
        # acquisition - current injection
        # NOTE(review): no `unit` is passed here. conversion=1e-9 is the NWB
        # data-to-unit factor; confirm the stored current units.
        current_injection, ci_start_time, ci_fs = (
            intracellular.CurrentInjection & cell).fetch1(
                'current_injection', 'current_injection_start_time',
                'current_injection_sampling_rate')
        nwbfile.add_stimulus(
            pynwb.icephys.CurrentClampStimulusSeries(
                name='CurrentClampStimulus',
                electrode=ic_electrode,
                conversion=1e-9,
                gain=1.0,
                data=current_injection,
                starting_time=ci_start_time,
                rate=ci_fs))

        # analysis - membrane potential without spike
        # Stored as a second PatchClampSeries inside an 'icephys' processing
        # module (not in acquisition).
        mp_rmv_spike = nwbfile.create_processing_module(
            name='icephys', description='Spike removal')
        mp_rmv_spike.add_data_interface(
            pynwb.icephys.PatchClampSeries(name='icephys',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp_wo_spike,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))

    # =============== Extracellular ====================
    # The single probe insertion for this session (fetch1), or None.
    probe_insertion = ((extracellular.ProbeInsertion
                        & session_key).fetch1() if extracellular.ProbeInsertion
                       & session_key else None)
    if probe_insertion:
        probe = nwbfile.create_device(name=probe_insertion['probe_name'])
        # Group name is "<probe_name>: <channel_counts>". Location is rendered
        # as "key: value; ..." from the insertion's ActionLocation row.
        electrode_group = nwbfile.create_electrode_group(name='; '.join([
            f'{probe_insertion["probe_name"]}: {str(probe_insertion["channel_counts"])}'
        ]),
                                                         description='N/A',
                                                         device=probe,
                                                         location='; '.join([
                                                             f'{k}: {str(v)}'
                                                             for k, v in
                                                             (reference.
                                                              ActionLocation
                                                              & probe_insertion
                                                              ).fetch1().items(
                                                              )
                                                         ]))

        # One electrode per probe channel, with channel_id as the row id.
        # imp=-1. is written for every electrode (no impedance is read here).
        for chn in (reference.Probe.Channel
                    & probe_insertion).fetch(as_dict=True):
            nwbfile.add_electrode(id=chn['channel_id'],
                                  group=electrode_group,
                                  filtering=hardware_filter,
                                  imp=-1.,
                                  x=chn['channel_x_pos'],
                                  y=chn['channel_y_pos'],
                                  z=chn['channel_z_pos'],
                                  location=electrode_group.location)

        # --- unit spike times ---
        # Declare the extra per-unit columns that add_unit fills below.
        nwbfile.add_unit_column(
            name='sampling_rate',
            description='Sampling rate of the raw voltage traces (Hz)')
        nwbfile.add_unit_column(name='unit_x',
                                description='x-coordinate of this unit (mm)')
        nwbfile.add_unit_column(name='unit_y',
                                description='y-coordinate of this unit (mm)')
        nwbfile.add_unit_column(name='unit_z',
                                description='z-coordinate of this unit (mm)')
        nwbfile.add_unit_column(
            name='cell_type',
            description='cell type (e.g. wide width, narrow width spiking)')

        for unit in (extracellular.UnitSpikeTimes
                     & probe_insertion).fetch(as_dict=True):
            # make an electrode table region (which electrode(s) is this unit coming from)
            # channel_id may be a scalar or an ndarray; a scalar is wrapped in
            # a list.
            # NOTE(review): these values are used as electrode-table indices,
            # while the electrodes were added with id=channel_id. The two
            # coincide only for 0-based, contiguous channel ids; confirm.
            # waveform_mean/sd reduce over axis 0. This presumably is the
            # spike axis of spike_waveform; confirm.
            nwbfile.add_unit(
                id=unit['unit_id'],
                electrodes=(unit['channel_id'] if isinstance(
                    unit['channel_id'], np.ndarray) else [unit['channel_id']]),
                sampling_rate=ecephys_fs,
                unit_x=unit['unit_x'],
                unit_y=unit['unit_y'],
                unit_z=unit['unit_z'],
                cell_type=unit['unit_cell_type'],
                spike_times=unit['spike_times'],
                waveform_mean=np.mean(unit['spike_waveform'], axis=0),
                waveform_sd=np.std(unit['spike_waveform'], axis=0))

    # =============== Behavior ====================
    # The single LickTrace row for this session (fetch1), or None.
    behavior_data = ((behavior.LickTrace
                      & session_key).fetch1() if behavior.LickTrace
                     & session_key else None)
    if behavior_data:
        behav_acq = pynwb.behavior.BehavioralTimeSeries(name='lick_trace')
        nwbfile.add_acquisition(behav_acq)
        # Drop the primary-key fields and the timing fields. Each remaining
        # column becomes one TimeSeries sharing the start time and rate.
        [behavior_data.pop(k) for k in behavior.LickTrace.primary_key]
        lt_start_time = behavior_data.pop('lick_trace_start_time')
        lt_fs = behavior_data.pop('lick_trace_sampling_rate')
        for b_k, b_v in behavior_data.items():
            behav_acq.create_timeseries(name=b_k,
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=b_v,
                                        starting_time=lt_start_time,
                                        rate=lt_fs)

    # =============== Photostimulation ====================
    # The single PhotoStimulation row for this session (fetch1), or None.
    photostim = ((stimulation.PhotoStimulation
                  & session_key).fetch1() if stimulation.PhotoStimulation
                 & session_key else None)
    if photostim:
        photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
        stim_device = nwbfile.create_device(
            name=photostim_device['device_name'])
        # Site name is "<hemisphere>-<brain_region>". Location is rendered as
        # "key: value; ..." from the matching ActionLocation row.
        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name='-'.join([photostim['hemisphere'],
                           photostim['brain_region']]),
            device=stim_device,
            excitation_lambda=float(photostim['photo_stim_excitation_lambda']),
            location='; '.join([
                f'{k}: {str(v)}' for k, v in (reference.ActionLocation
                                              & photostim).fetch1().items()
            ]),
            description=(stimulation.PhotoStimulationInfo
                         & photostim).fetch1('photo_stim_notes'))
        nwbfile.add_ogen_site(stim_site)

        # The series is added only when a time series is stored.
        # NOTE(review): conversion=1e-3 is the NWB data-to-unit factor;
        # confirm the stored power units.
        if photostim['photostim_timeseries'] is not None:
            nwbfile.add_stimulus(
                pynwb.ogen.OptogeneticSeries(
                    name='_'.join([
                        'photostim_on',
                        photostim['photostim_datetime'].strftime(
                            '%Y-%m-%d_%H-%M-%S')
                    ]),
                    site=stim_site,
                    resolution=0.0,
                    conversion=1e-3,
                    data=photostim['photostim_timeseries'],
                    starting_time=photostim['photostim_start_time'],
                    rate=photostim['photostim_sampling_rate']))

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    #                                                                       'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added in to the trial-table as additional columns (with column name
    # and column description)
    if acquisition.TrialSet & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo
        # Column name: the attribute name with 'trial_' removed.
        # Column description: the text after the last '#' on the first line
        # of str(heading) containing the attribute name. This is presumably
        # the attribute's comment in the table definition.
        # Primary-key attributes and start/stop_time are excluded.
        trial_columns = [
            {
                'name':
                tag.replace('trial_', ''),
                'description':
                re.search(
                    f'(?<={tag})(.*)#(.*)',
                    str((acquisition.TrialSet.Trial *
                         stimulation.TrialPhotoStimInfo
                         ).heading)).groups()[-1].strip()
            } for tag in (acquisition.TrialSet.Trial *
                          stimulation.TrialPhotoStimInfo).heading.names
            if tag not in (acquisition.TrialSet.Trial
                           & stimulation.TrialPhotoStimInfo).primary_key +
            ['start_time', 'stop_time']
        ]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have start_time and stop_time
        # also add `_time` suffix to all events
        # Event column descriptions come from reference.ExperimentalEvent.
        trial_events = set(((acquisition.TrialSet.EventTime & session_key) -
                            [{
                                'trial_event': 'trial_start'
                            }, {
                                'trial_event': 'trial_stop'
                            }]).fetch('trial_event'))
        event_names = [{
            'name': e + '_time',
            'description': d
        } for e, d in zip(*(reference.ExperimentalEvent & [{
            'event': k
        } for k in trial_events]).fetch('event', 'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        # Empty-string placeholders for trials without a TrialPhotoStimInfo
        # row, so that every trial supplies the same photostim columns.
        photostim_tag_default = {
            tag: ''
            for tag in stimulation.TrialPhotoStimInfo.heading.names
            if tag not in stimulation.TrialPhotoStimInfo.primary_key
        }

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial
                      & session_key).fetch(as_dict=True):
            # {trial_event: event_time} for this trial; only events in
            # trial_events (trial_start/trial_stop excluded).
            events = dict(
                zip(*(acquisition.TrialSet.EventTime & trial
                      & [{
                          'trial_event': e
                      } for e in trial_events]
                      ).fetch('trial_event', 'event_time')))

            trial_tag_value = ({
                **trial,
                **events,
                **(stimulation.TrialPhotoStimInfo & trial).fetch1()
            } if (stimulation.TrialPhotoStimInfo & trial) else {
                **trial,
                **events,
                **photostim_tag_default
            })

            # rename 'trial_id' to 'id'
            # The remaining primary-key fields are dropped (list used only
            # for its pop side effect).
            trial_tag_value['id'] = trial_tag_value['trial_id']
            [
                trial_tag_value.pop(k)
                for k in acquisition.TrialSet.Trial.primary_key
            ]

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {
                k.replace('trial_', ''): trial_tag_value.pop(k)
                for k in
                [n for n in trial_tag_value if n.startswith('trial_')]
            }

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        # Leave an existing file untouched unless overwrite is requested.
        if not overwrite and os.path.exists(
                os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name),
                       mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
def export_to_nwb(session_key, nwb_output_dir=default_nwb_output_dir, save=False, overwrite=False):
    """Assemble an NWB 2.0 file for one session and optionally save it.

    The ``NWBFile`` is populated from DataJoint tables restricted by
    ``session_key``. Session and subject metadata are added first, then each
    of the following, only when its source table has rows for this session:

    - the intracellular recording (``intracellular.Cell``), with an optional
      current injection
    - the extracellular probe insertion, with its electrodes and units
    - the lick trace
    - whisker behavior (one ``BehavioralTimeSeries`` per ``behavior.Whisker``
      row)
    - one stimulus site per ``stimulation.PhotoStimulation`` row

    Finally the trial table is built. All time series are written with
    explicit timestamps.

    Args:
        session_key: DataJoint restriction that should match exactly one
            ``acquisition.Session`` row (it is read with ``fetch1``).
        nwb_output_dir: Directory written to when ``save`` is true; it is
            created if it does not exist.
        save: If true, write the file as ``<nwbfile.identifier>.nwb`` in
            ``nwb_output_dir``.
        overwrite: If false (the default) and that file already exists,
            nothing is written.

    Returns:
        The populated ``NWBFile``, whether or not it was written to disk.
    """
    this_session = (acquisition.Session & session_key).fetch1()

    # "<subject_id>_<YYYY-MM-DD>_<session_id>". str.join needs every part to
    # be a str. NOTE(review): confirm session_id is stored as a string.
    identifier = '_'.join([this_session['subject_id'],
                           this_session['session_time'].strftime('%Y-%m-%d'),
                           this_session['session_id']])

    # =============== General ====================
    # -- NWB file - a NWB2.0 file for each session
    # session_time is combined with zero_zero_time, so the start time is
    # presumably midnight of the session date; confirm.
    nwbfile = NWBFile(
        session_description=this_session['session_note'],
        identifier=identifier,
        session_start_time=datetime.combine(this_session['session_time'], zero_zero_time),
        file_create_date=datetime.now(tzlocal()),
        experimenter='; '.join((acquisition.Session.Experimenter & session_key).fetch('experimenter')),
        institution=institution,
        experiment_description=experiment_description,
        related_publications=related_publications,
        keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    # genotype: all alleles of the subject joined with ' x '
    nwbfile.subject = pynwb.file.Subject(
        subject_id=this_session['subject_id'],
        description=subj['subject_description'],
        genotype=' x '.join((subject.Subject.Allele & session_key).fetch('allele')),
        sex=subj['sex'],
        species=subj['species'])

    # =============== Intracellular ====================
    # The single cell row for this session (fetch1), or None.
    cell = ((intracellular.Cell & session_key).fetch1()
            if intracellular.Cell & session_key
            else None)
    if cell:
        # metadata
        whole_cell_device = nwbfile.create_device(name=cell['device_name'])
        # The intracellular electrode is named after the session_id. Its
        # location is rendered as "key: value; ..." from the ActionLocation
        # row.
        ic_electrode = nwbfile.create_ic_electrode(
            name=cell['session_id'],
            device=whole_cell_device,
            description='N/A',
            filtering='N/A',
            location='; '.join([f'{k}: {str(v)}'
                                for k, v in (reference.ActionLocation & cell).fetch1().items()]))
        # acquisition - membrane potential
        mp, mp_timestamps = (intracellular.MembranePotential & cell).fetch1(
            'membrane_potential', 'membrane_potential_timestamps')
        nwbfile.add_acquisition(pynwb.icephys.PatchClampSeries(name='PatchClampSeries',
                                                               electrode=ic_electrode,
                                                               unit='mV',
                                                               conversion=1.0,
                                                               gain=1.0,
                                                               data=mp,
                                                               timestamps=mp_timestamps))
        # acquisition - current injection (only if recorded for this cell)
        if (intracellular.CurrentInjection & cell):
            current_injection, ci_timestamps = (intracellular.CurrentInjection & cell).fetch1(
                'current_injection', 'current_injection_timestamps')
            # NOTE(review): no `unit` is passed here. conversion=1e-6 is the
            # NWB data-to-unit factor; confirm the stored current units.
            nwbfile.add_stimulus(pynwb.icephys.CurrentClampStimulusSeries(name='CurrentClampStimulus',
                                                                          electrode=ic_electrode,
                                                                          conversion=1e-6,
                                                                          gain=1.0,
                                                                          data=current_injection,
                                                                          timestamps=ci_timestamps))

    # =============== Extracellular ====================
    # The single probe insertion for this session (fetch1), or None.
    probe_insertion = ((extracellular.ProbeInsertion & session_key).fetch1()
                       if extracellular.ProbeInsertion & session_key
                       else None)
    if probe_insertion:
        probe = nwbfile.create_device(name=probe_insertion['probe_name'])
        # Group name is "<probe_name>: <channel_counts>". Location is rendered
        # as "key: value; ..." from the BrainLocation referenced by this
        # insertion's InsertLocation.
        electrode_group = nwbfile.create_electrode_group(
            name='; '.join([f'{probe_insertion["probe_name"]}: {str(probe_insertion["channel_counts"])}']),
            description = 'N/A',
            device = probe,
            location = '; '.join([f'{k}: {str(v)}' for k, v in
                                  (reference.BrainLocation & (extracellular.ProbeInsertion.InsertLocation
                                   & probe_insertion)).fetch1().items()]))

        # One electrode per probe channel, with channel_id as the row id.
        # imp=-1. is written for every electrode (no impedance is read here).
        for chn in (reference.Probe.Channel & probe_insertion).fetch(as_dict=True):
            nwbfile.add_electrode(id=chn['channel_id'],
                                  group=electrode_group,
                                  filtering=hardware_filter,
                                  imp=-1.,
                                  x=chn['channel_x_pos'],
                                  y=chn['channel_y_pos'],
                                  z=chn['channel_z_pos'],
                                  location=electrode_group.location)

        # --- unit spike times ---
        # Declare the extra per-unit columns that add_unit fills below.
        nwbfile.add_unit_column(name='sampling_rate', description='Sampling rate of the raw voltage traces (Hz)')
        nwbfile.add_unit_column(name='cell_desc', description='description of this unit (e.g. cell type)')

        for unit in (extracellular.UnitSpikeTimes & probe_insertion).fetch(as_dict=True):
            # waveforms - mean and std over spikes per electrode
            # Each SpikeWaveform array is reduced over axis 0 (presumably the
            # spike axis; confirm). The results are stacked as rows, then
            # transposed so that each waveform array becomes one column.
            wfs = (extracellular.UnitSpikeTimes.SpikeWaveform & unit).fetch('spike_waveform')
            wfs_mean = np.vstack([wf.mean(axis=0) for wf in wfs]).T
            wfs_sd = np.vstack([wf.std(axis=0) for wf in wfs]).T

            # make an electrode table region (which electrode(s) is this unit coming from)
            # `- 1` turns channel_id into a 0-based electrode-table index.
            # This presumes channel ids are 1-based and contiguous; confirm.
            nwbfile.add_unit(id=unit['unit_id'],
                             electrodes=(extracellular.UnitSpikeTimes.UnitChannel & unit).fetch('channel_id') - 1,
                             sampling_rate=ecephys_fs,
                             cell_desc=unit['cell_desc'],
                             spike_times=unit['spike_times'],
                             waveform_mean=wfs_mean,
                             waveform_sd=wfs_sd)

    # =============== Behavior ====================
    # The single LickTrace row for this session (fetch1), or None.
    lick_trace_data = ((behavior.LickTrace & session_key).fetch1()
                       if behavior.LickTrace & session_key
                       else None)
    if lick_trace_data:
        behav_acq = pynwb.behavior.BehavioralTimeSeries(name='lick_trace')
        nwbfile.add_acquisition(behav_acq)
        # Drop the primary-key fields and the shared timestamps. Each
        # remaining column becomes one TimeSeries on those timestamps.
        [lick_trace_data.pop(k) for k in behavior.LickTrace.primary_key]
        timestamps = lick_trace_data.pop('lick_trace_timestamps')
        for b_k, b_v in lick_trace_data.items():
            behav_acq.create_timeseries(name=b_k,
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=b_v,
                                        timestamps=timestamps)

    if behavior.Whisker & session_key:
        for whisker_data in (behavior.Whisker & session_key).fetch(as_dict=True):
            # The name is "whisker_<whisker_config>". The expression
            # 'principal_' * flag adds a "principal_" prefix when
            # principal_whisker is 1/True. The flag is popped here, so it is
            # not exported as a series.
            behav_acq = pynwb.behavior.BehavioralTimeSeries(name='principal_'*whisker_data.pop('principal_whisker') + 'whisker_' + whisker_data['whisker_config'])
            nwbfile.add_acquisition(behav_acq)
            [whisker_data.pop(k) for k in behavior.Whisker.primary_key]
            timestamps = whisker_data.pop('behavior_timestamps')

            # get behavior data description from the comments of table definition
            # (the text after the last '#' on the first line of str(heading)
            # containing the attribute name; recomputed for every row)
            behavior_descriptions = {attr: re.search(f'(?<={attr})(.*)#(.*)',
                                                     str(behavior.Whisker.heading)).groups()[-1].strip()
                                     for attr in whisker_data}

            for b_k, b_v in whisker_data.items():
                behav_acq.create_timeseries(name=b_k,
                                            description=behavior_descriptions[b_k],
                                            unit='a.u.',
                                            conversion=1.0,
                                            data=b_v,
                                            timestamps=timestamps)

    # =============== Photostimulation ====================
    # One stimulus site (and an optional series) per PhotoStimulation row.
    if stimulation.PhotoStimulation & session_key:
        for photostim in (stimulation.PhotoStimulation & session_key).fetch(as_dict=True):
            photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
            # Reuse the device if an earlier row already created it.
            stim_device = (nwbfile.devices.get(photostim_device['device_name'])
                           if photostim_device['device_name'] in nwbfile.devices.keys()
                           else nwbfile.create_device(name=photostim_device['device_name']))
            # The site is named by photostim_id, which is also concatenated
            # into the series name below (so it must be a str).
            stim_site = pynwb.ogen.OptogeneticStimulusSite(
                name=photostim['photostim_id'],
                device=stim_device,
                excitation_lambda=float((stimulation.PhotoStimProtocol & photostim).fetch1('photo_stim_excitation_lambda')),
                location = '; '.join([f'{k}: {str(v)}' for k, v in
                                      (reference.ActionLocation & photostim).fetch1().items()]),
                description=(stimulation.PhotoStimProtocol & photostim).fetch1('photo_stim_notes'))
            nwbfile.add_ogen_site(stim_site)

            # NOTE(review): conversion=1e-3 is the NWB data-to-unit factor;
            # confirm the stored power units.
            if photostim['photostim_timeseries'] is not None:
                nwbfile.add_stimulus(pynwb.ogen.OptogeneticSeries(
                    name='photostimulation_' + photostim['photostim_id'],
                    site=stim_site,
                    resolution=0.0,
                    conversion=1e-3,
                    data=photostim['photostim_timeseries'],
                    timestamps=photostim['photostim_timestamps']))

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    #                                                                       'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added in to the trial-table as additional columns (with column name
    # and column description)
    if acquisition.TrialSet & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo
        # Column name: the attribute name with 'trial_' removed.
        # Column description: the text after the last '#' on the first line
        # of str(heading) containing the attribute name.
        # Primary-key attributes and start/stop_time are excluded.
        trial_columns = [{'name': tag.replace('trial_', ''),
                          'description': re.search(
                              f'(?<={tag})(.*)#(.*)',
                              str((acquisition.TrialSet.Trial
                                   * stimulation.TrialPhotoStimParam).heading)).groups()[-1].strip()}
                         for tag in (acquisition.TrialSet.Trial * stimulation.TrialPhotoStimParam).heading.names
                         if tag not in (acquisition.TrialSet.Trial
                                        & stimulation.TrialPhotoStimParam).primary_key + ['start_time', 'stop_time']]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have start_time and stop_time
        # also add `_time` suffix to all events
        # Event column descriptions come from reference.ExperimentalEvent.
        trial_events = set(((acquisition.TrialSet.EventTime & session_key)
                            - [{'trial_event': 'trial_start'}, {'trial_event': 'trial_stop'}]).fetch('trial_event'))
        event_names = [{'name': e + '_time', 'description': d}
                       for e, d in zip(*(reference.ExperimentalEvent & [{'event': k}
                                                                        for k in trial_events]).fetch('event',
                                                                                                      'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        # Empty-string placeholders for trials without a TrialPhotoStimParam
        # row, so that every trial supplies the same photostim columns.
        photostim_tag_default = {tag: '' for tag in stimulation.TrialPhotoStimParam.heading.names
                                 if tag not in stimulation.TrialPhotoStimParam.primary_key}

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial & session_key).fetch(as_dict=True):
            # {trial_event: event_time} for this trial; only events in
            # trial_events (trial_start/trial_stop excluded).
            events = dict(zip(*(acquisition.TrialSet.EventTime & trial
                                & [{'trial_event': e} for e in trial_events]).fetch('trial_event', 'event_time')))
            # shift event times to be relative to session_start (currently relative to trial_start)
            events = {k: v + trial['start_time'] for k, v in events.items()}

            trial_tag_value = ({**trial, **events, **(stimulation.TrialPhotoStimParam & trial).fetch1()}
                               if (stimulation.TrialPhotoStimParam & trial)
                               else {**trial, **events, **photostim_tag_default})

            trial_tag_value['id'] = trial_tag_value['trial_id']  # rename 'trial_id' to 'id'
            # convert None to np.nan since nwb fields does not take None
            # (only existing keys are reassigned, so the dict size is unchanged
            # during iteration)
            for k, v in trial_tag_value.items():
                trial_tag_value[k] = v if v is not None else np.nan
            [trial_tag_value.pop(k) for k in acquisition.TrialSet.Trial.primary_key]

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {k.replace('trial_', ''): trial_tag_value.pop(k)
                           for k in [n for n in trial_tag_value if n.startswith('trial_')]}

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        # Leave an existing file untouched unless overwrite is requested.
        if not overwrite and os.path.exists(os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name), mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile