Example #1
    def populate(self, *restrictions, level='New', display_progress=False):

        # populate new jobs only
        if level == 'New':
            cond = {}
        # populate new jobs and error jobs only
        elif level == 'Error':
            cond = 'job_status in ("Success", "Partial Success")'
        # populate new jobs, partial success jobs and error jobs
        elif level == 'Partial':
            cond = 'job_status in ("Success")'
        # populate all jobs, regardless of current status
        elif level == 'All':
            cond = []
        else:
            raise ValueError(f"Unknown level: {level}")

        self.key_source = (Session - (self & cond)) & dj.AndList(restrictions)
        keys = self.key_source.fetch('KEY')

        for key in (tqdm(keys, position=0) if display_progress else keys):
            self.make(key)
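
To make the key_source line concrete: a plain Python list used as a DataJoint restriction is OR-ed, while dj.AndList AND-s its elements, so the caller-supplied *restrictions are all applied together on top of the level condition. A minimal sketch (the subject/session values are made up):

import datajoint as dj

# Restrictions as a caller might pass them to populate(*restrictions).
restrictions = ('subject_id = "473361"', 'session >= 57')

# dj.AndList is a list subclass whose elements are combined by logical AND
# when used as a restriction; a plain list would be combined by OR instead.
and_restriction = dj.AndList(restrictions)
print(and_restriction)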
Example #2
 def restriction(self):
     # first element is the spec's restriction,
     # second element is built from the request's query parameters
     return dj.AndList(
         [
             self.dj_restriction(),
             {
                 k: (
                     datetime.fromtimestamp(float(v)).isoformat()
                     if re.match(
                         r"^date.*$",
                         self.fetch_metadata["query"].heading.attributes[k].type,
                     )
                     else v
                 )
                 for k, v in request.args.items()
                 if k in self.fetch_metadata["query"].heading.attributes
             },
         ]
     )
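
The dict comprehension above rewrites query parameters whose attributes have a date-like type: the raw value is taken to be a Unix timestamp string and is converted to an ISO datetime string before being AND-ed with the spec's restriction. A small stdlib-only sketch of that conversion (the parameter value is made up):

from datetime import datetime

raw_value = "1609459200.0"   # hypothetical ?session_date=... query parameter
iso_value = datetime.fromtimestamp(float(raw_value)).isoformat()
print(iso_value)             # e.g. '2021-01-01T00:00:00' if the server runs in UTC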
Example #3
    def _combine_transfer_recipes(self, transfer_step):
        """
        Combines multiple transfer recipes and their restrictions, as specified by the post_restr attribute.
        The combination is transfer-step-specific: only the recipes relevant to a given transfer step are combined.

        Combining recipes is simple and the user does not need to interact with this method directly. As an example,
        assume you have two recipe tables, TrainerRecipe and ModelRecipe; you can then attach both recipes to your
        TransferTrainedModel table as follows:

        ``` Python
            TransferTrainedModel.transfer_recipe = [TrainerRecipe, ModelRecipe]
        ```

        The rest (combining the recipes and their restrictions) is taken care of by this method.

        Args:
            transfer_step (int): table population transfer step.

        Returns:
            string or datajoint AndList: The restriction of a single recipe, or the combined restriction of multiple recipes.
        """

        if not isinstance(self.transfer_recipe, Sequence):
            return self.transfer_recipe
        # else: get the recipes that have an entry for a specific transfer step
        transfer_recipe = []
        for tr in self.transfer_recipe:
            # check if an entry exists for a specific transfer step in the recipe
            if tr & f"transfer_step = {transfer_step}":
                # if it exists add that entry to the list of recipes (relevant for a specific transfer step)
                transfer_recipe.append(tr & f"transfer_step = {transfer_step}")
        if not transfer_recipe:
            return self.proj() - self  # return something empty
        # join all the recipes (and their post_restr)
        joined = transfer_recipe[0]
        if len(transfer_recipe) > 1:
            for t in transfer_recipe[1:]:
                joined *= t  # all combination of recipes
            joined.post_restr = dj.AndList(
                [recipe.post_restr for recipe in self.transfer_recipe])
        return joined
Example #4
    def _record_dependency(jwt_payload: dict,
                           schema_name: str,
                           table_name: str,
                           restriction: list = []) -> list:
        """
        Return summary of dependencies associated with a restricted table. Will only show
        dependencies that user has access to.

        :param jwt_payload: Dictionary containing databaseAddress, username, and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of schema
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param restriction: Sequence of filters as ``dict`` with ``attributeName``,
            ``operation``, ``value`` keys defined, defaults to ``[]``
        :type restriction: list
        :return: Tables that are dependent on specific records.
        :rtype: list
        """
        _DJConnector._set_datajoint_config(jwt_payload)
        virtual_module = dj.VirtualModule(schema_name, schema_name)
        table = getattr(virtual_module, table_name)
        attributes = table.heading.attributes
        # Retrieve dependencies of the restricted table
        dependencies = [
            dict(
                schema=descendant.database,
                table=descendant.table_name,
                accessible=True,
                count=len((table if descendant.full_table_name ==
                           table.full_table_name else descendant * table)
                          & dj.AndList([
                              _DJConnector._filter_to_restriction(
                                  f, attributes[f["attributeName"]].type)
                              for f in restriction
                          ])),
            ) for descendant in table().descendants(as_objects=True)
        ]
        return dependencies
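
The restriction argument is a list of dicts with attributeName, operation, and value keys; each dict is turned into a DataJoint condition by _DJConnector._filter_to_restriction (not shown here) and the resulting conditions are AND-ed. The converter below is only an illustrative stand-in for the kind of string such a helper produces, not its actual implementation:

# Hypothetical filter in the documented format.
filt = {"attributeName": "session_id", "operation": ">=", "value": 57}

def to_condition(f: dict) -> str:
    # Naive stand-in: quote string values, leave numbers bare.
    value = f["value"] if isinstance(f["value"], (int, float)) else f'"{f["value"]}"'
    return f'{f["attributeName"]} {f["operation"]} {value}'

print(to_condition(filt))   # session_id >= 57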
Example #5
def compare_ni_and_bpod_times(q_sess=dj.AndList(['subject_id = "473361"', 'session >= 57']), event_to_align='bitcodestart'):
    '''
    Compare NI-TIME and BPOD-TIME
    This is a critical validation for ephys timing alignment
    
    [Conclusions]:
    1. Even Bpod time is not accurate... (occasionally up to 3 ms of jitter)
    2. At least for choice, NIXD is very close to the ground-truth NIXA (< 0.5 ms)
    3. We should always trust the NI device and feed everything into it!
    4. NOTE the floating-point precision problem: use double or at least DECIMAL(8,4) for times!

    '''
    
    event_all = ['bitcodestart', 'go', 'choice', 'trialend']
    event_to_compare = [e for e in event_all if e != event_to_align]
    
    # 1. -- all events, Bpod vs NIXD --
    # Ephys time
    ephys_times = (ephys.TrialEvent & q_sess & f'trial_event_type IN {tuple(event_all)}').fetch(format='frame')
    ephys_times = ephys_times.reset_index().pivot(index = ['subject_id', 'session', 'trial'], columns='trial_event_type').trial_event_time.astype(float)
    session_times = ephys_times.bitcodestart
    ephys_times = ephys_times.sub(ephys_times[event_to_align], axis=0).drop(columns=event_to_align)  # To each bitcode start (bpodstart is sometimes problematic if a new bpod session is started)

    # Bpod time
    bpod_times = (experiment.TrialEvent & q_sess & f'trial_event_type IN {tuple(event_all)}' ).fetch(format='frame')
    bpod_times = bpod_times.reset_index().pivot(index = ['subject_id', 'session', 'trial'], columns='trial_event_type').trial_event_time.astype(float)  # Already related to bpod start
    bpod_times = bpod_times.sub(bpod_times[event_to_align], axis=0).drop(columns=event_to_align)  # To each trial's bpod start

    # Plot: Bpod vs NIXD, distribution of differences
    fig = plt.figure(figsize=(8,13))
    ax = fig.subplots(len(event_to_compare), 2)
    max_error = 0
    for n, event in enumerate(event_to_compare):
        ax[n, 0].plot(ephys_times[event], bpod_times[event], '*', label=event)
        ax[n, 1].hist((np.array(bpod_times[event]) - np.array(ephys_times[event])) * 1000, bins=100, label=event)
        ax_max = max(max(ax[n, 0].get_xlim()), max(ax[n, 0].get_ylim()))
        ax_min = min(min(ax[n, 0].get_xlim()), min(ax[n, 0].get_ylim()))
        ax[n, 0].plot([ax_min, ax_max], [ax_min, ax_max], 'k:')
        ax[n, 0].legend()
        max_error = max(max_error, max(np.abs(ax[n, 1].get_xlim())))

    for _a in ax[:, 1]:
        _a.set_xlim(-max_error, max_error)

    ax[-1, 0].set_xlabel(f'NI time to {event_to_align} (s)')
    ax[-1, 0].set_ylabel('Bpod time (s)')
    ax[-1, 1].set_xlabel('Bpod time - NI time (ms)')
    ax[0, 1].set_title(f'{q_sess}\ntotal {len(bpod_times)} trials', fontsize=8)


    # 2. -- first lick: NIXA (raw lickboard, close to ground truth?), NIXD (bpod digital to NI), and Bpod (bpod csv file) --
    ephys_go_cue = (ephys.TrialEvent & 'trial_event_type = "go"' & q_sess).proj(ephys_go='trial_event_time')
    nixa_first_lick = ephys_go_cue.proj().aggr(
        (ephys.ActionEvent * ephys_go_cue) & 'action_event_time >= ephys_go',
        nixa='min(action_event_time)')    # Session-time of first lick of each trial

    q_all = (nixa_first_lick.proj(..., tmp='trial_event_id')   # NIXA
            * (ephys.TrialEvent & 'trial_event_type = "choice"').proj(nixd='trial_event_time', tmp1='trial_event_id')   # NIXD
            * (ephys.TrialEvent & f'trial_event_type = "{event_to_align}"').proj(ni_align='trial_event_time')
            * (experiment.TrialEvent & 'trial_event_type = "choice"').proj(bpod='trial_event_time', tmp2='trial_event_id')  # Bpod csv
            * (experiment.TrialEvent & f'trial_event_type = "{event_to_align}"').proj(bpod_align='trial_event_time', tmp3='trial_event_id')
            & q_sess).proj(...,
                           nixa_aligned='nixa - ni_align',
                           nixd_aligned='nixd - ni_align',
                           bpod_aligned='bpod - bpod_align')

    diff_cut_off = 30   # (ms) Cutoff beyond which the discrepancy is due to a lick faster than the 10 ms go-cue bitcode (since decreased to 1 ms)
    nixa, nixd, bpod, ni_align = (q_all & f'abs(nixa_aligned - nixd_aligned) < {diff_cut_off/1000}').fetch('nixa_aligned', 'nixd_aligned', 'bpod_aligned', 'ni_align')
    n_cut_off = len(q_all) - len(nixa)

    nixa = nixa.astype(float) * 1000
    nixd = nixd.astype(float) * 1000
    bpod = bpod.astype(float) * 1000

    # (Bpod - NIXA) vs (NIXD - NIXA)
    fig = plt.figure(figsize=(8,8))
    ax = fig.subplots(2, 2)
    ax[0,0].plot(bpod - nixa, nixd - nixa, '.')
    ax[0,0].set(xlabel='bpod - nixa (ms)', ylabel='nixd - nixa')
    ax[0,1].hist(bpod - nixa, bins = 100)
    ax[0,1].set(xlabel='bpod - nixa (ms)', title=f'first lick after go, total_trial = {len(bpod)}, cut_off = {n_cut_off}')
    ax[1,0].hist(nixd - nixa, bins = 100)
    ax[1,0].set(xlabel='nixd - nixa (ms)')
    ax[1,1].hist(bpod - nixd, bins = 100)
    ax[1,1].set(xlabel='bpod - nixd (ms)')

    # 3. -- Error as a function of session time --
    # I have narrowed the unacceptable up-to-10-ms jitter between times down to
    # floating-point precision!
    # Here I just compare choice times from ephys_times and q_all, which use
    # exactly the same data (ephys.TrialEvent) but different approaches:
    #    ephys_times: fetch and subtract
    #    q_all: subtract and fetch
    non_nan = ~np.isnan(ephys_times.choice.astype(float))
    session_times = session_times[non_nan]
    fetch_and_subtract = ephys_times.choice[non_nan]
    subtract_and_fetch = q_all.fetch('nixd_aligned').astype(float)
    ax = plt.figure(figsize=(10,10)).subplots(3,1)
    ax[0].plot(session_times, (fetch_and_subtract - subtract_and_fetch) * 1000, '.')
    ax[0].set(title='NIXD choice, fetch first - subtract first')
    
    # Choice time, Bpod vs NIXA
    ax[1].plot(ni_align, bpod - nixa, '.')
    ax[1].set(title='choice (Bpod - NIXA)', ylabel='time error (ms)')
    
    # Choice time, NIXD vs NIXA
    ax[2].plot(ni_align, nixd - nixa, '.')
    ax[2].set(title='choice (NIXD - NIXA)', xlabel='time in session (s)')
    
    for _a in ax.flat:
        _a.label_outer()
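
Conclusion 4 in the docstring is easy to reproduce in isolation: session times of a few thousand seconds already sit at the sub-millisecond limit of single-precision floats, which is why double (or DECIMAL(8,4)) columns are recommended for event times. A self-contained sketch with made-up numbers:

import numpy as np
from decimal import Decimal

t_event = 7203.4567    # seconds from session start (made up)
t_align = 7200.1234

# Subtracting after rounding to float32 loses sub-millisecond precision.
err = float(np.float32(t_event) - np.float32(t_align)) - (t_event - t_align)
print(f"float32 subtraction error: {err * 1000:.3f} ms")

# DECIMAL(8,4)-style arithmetic keeps the difference exact to 0.1 ms.
print(Decimal("7203.4567") - Decimal("7200.1234"))   # 3.3333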
Example #6
def compare_pc_and_bpod_times(q_sess=dj.AndList(['water_restriction_number = "HH09"', 'session < 10'])):
    '''
    Compare PC-TIME and BPOD-TIME of pybpod csv file
    This is a critical validation for ephys timing alignment
    This function requires raw .csv data to be present under `behavior_bpod`.`project_paths`    
    
    [Conclusions]:
    1. PC-TIME has an average delay of 4 ms and sometimes much much longer!
    2. We should always use BPOD-TIME (at least for all offline analysis)

    Parameters
    ----------
    q_sess : dj.AndList, optional
        Query restricting which sessions to compare. The default is dj.AndList(['water_restriction_number = "HH09"', 'session < 10']).

    Returns
    -------
    None.

    '''
    csv_folders = (behavior_ingest.BehaviorBpodIngest.BehaviorFile * lab.WaterRestriction & q_sess
                 ).fetch('behavior_file')
    
    current_project_path = dj.config['custom']['behavior_bpod']['project_paths'][0]
    current_ingestion_folder = re.findall('(.*)Behavior_rigs.*', current_project_path)[0]
    
    event_to_compare = ['GoCue', 'Choice_L', 'Choice_R', 'ITI', 'End']
    event_pc_times = {event: list() for event in event_to_compare}
    event_bpod_times = {event: list() for event in event_to_compare}
    
    # --- Loop over all files
    for csv_folder in tqdm(csv_folders):
        csv_file = (current_ingestion_folder + re.findall('.*(Behavior_rigs.*)', csv_folder)[0]
                    + '/'+ csv_folder.split('/')[-1] + '.csv')
        df_behavior_session = load_and_parse_a_csv_file(csv_file)
        
        # ---- Integrity check of the current bpodsess file ---
        # It must have at least one 'trial start' and 'trial end'
        trial_start_idxs = df_behavior_session[(df_behavior_session['TYPE'] == 'TRIAL') & (
                    df_behavior_session['MSG'] == 'New trial')].index
        if not len(trial_start_idxs):
            continue   # Make sure at least one 'trial start' exists; otherwise move on to the next bpodsess file, if any
            
        # Trial end index = next trial's start (or the last row of the file for the final trial)
        trial_end_idxs = trial_start_idxs[1:].append(pd.Index([max(df_behavior_session.index)]))
        
        for trial_start_idx, trial_end_idx in zip(trial_start_idxs, trial_end_idxs):
            df_behavior_trial = df_behavior_session[trial_start_idx:trial_end_idx + 1]
            
            pc_time_trial_start = df_behavior_trial.iloc[0]['PC-TIME']
            
            for event in event_to_compare:
                idx = df_behavior_trial[(df_behavior_trial['TYPE'] == 'TRANSITION') & (
                                df_behavior_trial['MSG'] == event)].index
                if not len(idx):
                    continue
                
                # PC-TIME
                pc_time = (df_behavior_trial.loc[idx]['PC-TIME']- pc_time_trial_start).values / np.timedelta64(1, 's')
                event_pc_times[event].append(pc_time[0])
                
                # BPOD-TIME
                bpod_time = df_behavior_trial.loc[idx]['BPOD-INITIAL-TIME'].values
                event_bpod_times[event].append(bpod_time[0])
        
    # --- Plotting ---
    fig = plt.figure()
    ax = fig.subplots(1,2)
    for event in event_pc_times:
        ax[0].plot(event_bpod_times[event], event_pc_times[event], '*', label=event)
        ax[1].hist((np.array(event_pc_times[event]) - np.array(event_bpod_times[event])) * 1000, range=(0,20), bins=500, label=event)
    
    ax_max = max(max(ax[0].get_xlim()), max(ax[0].get_ylim()))
    ax[0].plot([0, ax_max], [0, ax_max], 'k:')    
    ax[0].set_xlabel('Bpod time from trial start (s)')
    ax[0].set_ylabel('PC time (s)')
    ax[1].set_xlabel('PC time lag (ms)')
    ax[1].set_title(f'{q_sess}\ntotal {len(csv_folders)} bpod csv files')
    ax[0].legend()
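
The per-event loop above leans on pandas datetime arithmetic: assuming load_and_parse_a_csv_file parses the PC-TIME column as datetime64, subtracting the trial-start stamp gives timedelta64 values, which dividing by np.timedelta64(1, 's') converts to seconds. A minimal sketch with made-up timestamps:

import numpy as np
import pandas as pd

pc_times = pd.to_datetime(pd.Series(['2020-01-01 12:00:00.004',
                                     '2020-01-01 12:00:01.250']))
pc_time_trial_start = pd.Timestamp('2020-01-01 12:00:00.000')

seconds = (pc_times - pc_time_trial_start).values / np.timedelta64(1, 's')
print(seconds)   # seconds as floats: 0.004 and 1.25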
Example #7
    def train_evaluate(self, auto_params):
        """
        For a given set of parameters, add an entry to the corresponding tables, and populate the trained model
        table for that specific entry.

        Args:
            auto_params (dict): dictionary of dictionaries where each dictionary specifies a single parameter to be optimized.

        Returns:
            float: the score of the trained model for the specific entry in trained model table
        """
        config = self._combine_params(self._split_config(auto_params),
                                      self.fixed_params)

        # insert the stuff into their corresponding tables
        dataset_hash = make_hash(config["dataset"])
        entry_exists = {
            "dataset_fn": "{}".format(self.fns["dataset"])
        } in self.trained_model_table.dataset_table() and {
            "dataset_hash": "{}".format(dataset_hash)
        } in self.trained_model_table.dataset_table()
        if not entry_exists:
            self.trained_model_table.dataset_table().add_entry(
                self.fns["dataset"],
                config["dataset"],
                dataset_fabrikant=self.architect,
                dataset_comment=self.comment,
            )

        model_hash = make_hash(config["model"])
        entry_exists = {
            "model_fn": "{}".format(self.fns["model"])
        } in self.trained_model_table.model_table() and {
            "model_hash": "{}".format(model_hash)
        } in self.trained_model_table.model_table()
        if not entry_exists:
            self.trained_model_table.model_table().add_entry(
                self.fns["model"],
                config["model"],
                model_fabrikant=self.architect,
                model_comment=self.comment,
            )

        trainer_hash = make_hash(config["trainer"])
        entry_exists = {
            "trainer_fn": "{}".format(self.fns["trainer"])
        } in self.trained_model_table.trainer_table() and {
            "trainer_hash": "{}".format(trainer_hash)
        } in self.trained_model_table.trainer_table()
        if not entry_exists:
            self.trained_model_table.trainer_table().add_entry(
                self.fns["trainer"],
                config["trainer"],
                trainer_fabrikant=self.architect,
                trainer_comment=self.comment,
            )

        # get the primary key values for all those entries
        restriction = (
            'dataset_fn in ("{}")'.format(self.fns["dataset"]),
            'dataset_hash in ("{}")'.format(dataset_hash),
            'model_fn in ("{}")'.format(self.fns["model"]),
            'model_hash in ("{}")'.format(model_hash),
            'trainer_fn in ("{}")'.format(self.fns["trainer"]),
            'trainer_hash in ("{}")'.format(trainer_hash),
        )

        # populate the table for those primary keys
        self.trained_model_table().populate(*restriction)

        # get the score of the model for this specific set of hyperparameters
        score = (self.trained_model_table()
                 & dj.AndList(restriction)).fetch("score")[0]

        return score
Example #8
class AlignedTrialSpikes(dj.Computed):
    definition = """
    # Spike times of each trial aligned to different events
    -> DefaultCluster
    -> behavior.TrialSet.Trial
    -> Event
    ---
    trial_spike_times=null:   longblob     # spike time for each trial, aligned to different event times
    trial_spikes_ts=CURRENT_TIMESTAMP:    timestamp
    """
    key_source = behavior.TrialSet * DefaultCluster * Event & \
        ['event in ("stim on", "feedback")',
         dj.AndList([wheel.MovementTimes, 'event="movement"'])]

    def make(self, key):

        cluster = DefaultCluster() & key
        spike_times = cluster.fetch1('cluster_spikes_times')
        event = (Event & key).fetch1('event')

        if event == 'movement':
            trials = behavior.TrialSet.Trial * wheel.MovementTimes & key
            trial_keys, trial_start_times, trial_end_times, \
                trial_stim_on_times, trial_feedback_times, \
                trial_movement_times = \
                trials.fetch('KEY', 'trial_start_time', 'trial_end_time',
                             'trial_stim_on_time', 'trial_feedback_time',
                             'movement_onset')
        else:
            trials = behavior.TrialSet.Trial & key
            trial_keys, trial_start_times, trial_end_times, \
                trial_stim_on_times, trial_feedback_times = \
                trials.fetch('KEY', 'trial_start_time', 'trial_end_time',
                             'trial_stim_on_time', 'trial_feedback_time')

        # trial idx of each spike
        spike_ids = np.searchsorted(
            np.sort(
                np.hstack(np.vstack([trial_start_times, trial_end_times]).T)),
            spike_times)

        trial_spks = []
        for itrial, trial_key in enumerate(trial_keys):

            trial_spk = dict(**trial_key,
                             cluster_id=key['cluster_id'],
                             probe_idx=key['probe_idx'])

            trial_spike_time = spike_times[spike_ids == itrial * 2 + 1]

            if not len(trial_spike_time):
                trial_spk['trial_spike_times'] = np.array([])
            else:
                if event == 'stim on':
                    trial_spk['trial_spike_times'] = \
                        trial_spike_time - trial_stim_on_times[itrial]
                elif event == 'movement':
                    trial_spk['trial_spike_times'] = \
                        trial_spike_time - trial_movement_times[itrial]
                elif event == 'feedback':
                    if trial_feedback_times[itrial]:
                        trial_spk['trial_spike_times'] = \
                            trial_spike_time - trial_feedback_times[itrial]
                    else:
                        continue

            trial_spk['event'] = event
            trial_spks.append(trial_spk.copy())

        self.insert(trial_spks)
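
The key_source above mixes both kinds of restriction lists: the outer plain list is OR-ed, while the inner dj.AndList is AND-ed, so keys are kept when the event is "stim on" or "feedback", or when the event is "movement" and wheel.MovementTimes has a matching entry. Sketched below without the table objects (which would need a database connection); wheel.MovementTimes would be the first element of the AndList in the real key_source:

import datajoint as dj

# Outer list: OR. Inner AndList: AND.
restriction = [
    'event in ("stim on", "feedback")',
    dj.AndList(['event="movement"']),
]
print(restriction)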
Example #9
import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

from locker import analysis as alys, colors, colordict
from locker import data

alys.FirstOrderSignificantPeaks() * data.Cells() & 'eod_coeff = 0 and abs(stimulus_coeff) = 1 and baseline_coeff=0 and refined=1'

restr = dj.AndList([
    ['cell_type="e-cell"', 'cell_type = "i-cell"'],  # inner list is OR-ed: either cell type
    'n_harmonics = 0',
    'am = 0',
    'contrast = 20.0',
    'refined=1',
    '(abs(delta_f) between 95 and 105)',
    'eod_coeff = 0 and abs(stimulus_coeff) = 1 and baseline_coeff=0',
])
temp = alys.FirstOrderSignificantPeaks() * data.Runs() * data.Cells() & restr
temp = temp.proj('vector_strength', 'frequency','cell_type', pos='delta_f > 0')

pd_pyr = pd.DataFrame(temp.fetch())

g = sns.catplot(x="cell_type", y="vector_strength", hue="pos", data=pd_pyr, kind="bar")
g.set_xlabels('Cell Type')
g.set_ylabels('Vector Strength')
sns.despine()

fig = plt.gcf()
Example #10
class DepthPeth(dj.Computed):
    definition = """
    -> ephys.ProbeInsertion
    -> ephys.Event
    -> TrialType
    ---
    depth_peth          : blob@ephys   # firing rate for each depth bin and time bin
    depth_bin_centers   : longblob     # centers of the depth bins
    time_bin_centers    : longblob     # centers of the time bin
    depth_baseline      : longblob     # baseline for each depth bin, average activity during -0.3 to 0 relative to the event
    depth_peth_ts=CURRENT_TIMESTAMP : timestamp
    """
    key_source = ephys.ProbeInsertion * ephys.Event * \
        (TrialType & 'trial_type="Correct All"') & ephys.DefaultCluster & \
        behavior.TrialSet & \
        ['event in ("stim on", "feedback")',
         dj.AndList([wheel.MovementTimes, 'event="movement"'])]

    def make(self, key):

        clusters_spk_depths, clusters_spk_times, clusters_ids = \
            (ephys.DefaultCluster & key).fetch(
                'cluster_spikes_depths', 'cluster_spikes_times', 'cluster_id')

        spikes_depths = np.hstack(clusters_spk_depths)
        spikes_times = np.hstack(clusters_spk_times)
        spikes_clusters = np.hstack(
            [[cluster_id]*len(cluster_spk_depths)
             for (cluster_id, cluster_spk_depths) in zip(clusters_ids,
                                                         clusters_spk_depths)])

        if key['event'] == 'movement':
            q = behavior.TrialSet.Trial * wheel.MovementTimes & key & 'trial_feedback_type=1'
        else:
            q = behavior.TrialSet.Trial & key & 'trial_feedback_type=1'

        trials = q.fetch()

        bin_size_depth = 80
        min_depth = np.nanmin(spikes_depths)
        max_depth = np.nanmax(spikes_depths)
        bin_edges = np.arange(min_depth, max_depth, bin_size_depth)
        spk_bin_ids = np.digitize(spikes_depths, bin_edges)

        edges = np.hstack([bin_edges, [bin_edges[-1]+bin_size_depth]])
        key.update(trial_type='Correct All',
                   depth_bin_centers=(edges[:-1] + edges[1:])/2)

        if key['event'] == 'feedback':
            event_times = trials['trial_feedback_time']
        elif key['event'] == 'stim on':
            event_times = trials['trial_stim_on_time']
        elif key['event'] == 'movement':
            event_times = trials['movement_onset']

        peth_list = []
        baseline_list = []

        for i in tqdm(np.arange(len(bin_edges)) + 1, position=0):
            f = spk_bin_ids == i
            spikes_ibin = spikes_times[f]
            spike_clusters = spikes_clusters[f]
            cluster_ids = np.unique(spike_clusters)

            peths, binned_spikes = singlecell.calculate_peths(
                spikes_ibin, spike_clusters, cluster_ids,
                event_times, pre_time=0.3, post_time=1)
            if len(peths.means):
                time = peths.tscale
                peth = np.sum(peths.means, axis=0)
                baseline = peth[np.logical_and(time > -0.3, time < 0)]
                mean_bsl = np.mean(baseline)

                peth_list.append(peth)
                baseline_list.append(mean_bsl)
            else:
                peth_list.append(np.zeros_like(peths.tscale))
                baseline_list.append(0)

        key.update(depth_peth=np.vstack(peth_list),
                   depth_baseline=np.array(baseline_list),
                   time_bin_centers=peths.tscale)
        self.insert1(key, skip_duplicates=True)
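
A note on the depth binning above: np.digitize returns 1-based bin indices for these edges, which is why the PETH loop iterates over np.arange(len(bin_edges)) + 1. A small sketch with made-up spike depths:

import numpy as np

spikes_depths = np.array([5.0, 90.0, 170.0, 1500.0])   # made-up depths (um)
bin_edges = np.arange(0, 2000, 80)                      # bin_size_depth = 80
spk_bin_ids = np.digitize(spikes_depths, bin_edges)
print(spk_bin_ids)   # [ 1  2  3 19]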
Example #11
    def _fetch_records(
        query,
        restriction: list = [],
        limit: int = 1000,
        page: int = 1,
        order=None,
        fetch_blobs=False,
        fetch_args=[],
    ) -> tuple:
        """
        Get records from query.

        :param query: any datajoint object related to QueryExpression
        :type query: datajoint ``QueryExpression`` or related object
        :param restriction: Sequence of filters as ``dict`` with ``attributeName``,
            ``operation``, ``value`` keys defined, defaults to ``[]``
        :type restriction: list, optional
        :param limit: Max number of records to return, defaults to ``1000``
        :type limit: int, optional
        :param page: Page number to return, defaults to ``1``
        :type page: int, optional
        :param order: Sequence to order records, defaults to ``['KEY ASC']``. See
            :class:`~datajoint.fetch.Fetch` for more info.
        :type order: list, optional
        :return: Attribute headers, records as rows of values, and the total number of records
            that can be paged
        :rtype: tuple
        """

        # Get table object from name
        attributes = query.heading.attributes
        # Fetch rows as dicts (blobs included only when requested) to build
        #   the list of rows to return
        query_restricted = query & dj.AndList([
            _DJConnector._filter_to_restriction(
                f, attributes[f["attributeName"]].type) for f in restriction
        ])

        order_by = (fetch_args.pop("order_by")
                    if "order_by" in fetch_args else ["KEY ASC"])
        order_by = order if order else order_by

        limit = fetch_args.pop("limit") if "limit" in fetch_args else limit

        if fetch_blobs and not fetch_args:
            fetch_args = [*query.heading.attributes]
        elif not fetch_args:
            fetch_args = query.heading.non_blobs
        else:
            attributes = {
                k: v
                for k, v in attributes.items() if k in fetch_args
            }
        non_blobs_rows = query_restricted.fetch(
            *fetch_args,
            as_dict=True,
            limit=limit,
            offset=(page - 1) * limit,
            order_by=order_by,
        )

        # Buffer list to be returned
        rows = []

        # Loop through each row; convert temporal types and replace blobs
        #   with '=BLOB=' for JSON encoding
        for non_blobs_row in non_blobs_rows:
            # Buffer object to store the attributes
            row = []
            # Loop through each attribute and append it to the row with
            #   type-specific conversions
            for attribute_name, attribute_info in attributes.items():
                if not attribute_info.is_blob:
                    if non_blobs_row[attribute_name] is None:
                        # If the value is None, just append None
                        row.append(None)
                    elif attribute_info.type == "date":
                        # Date attribute: convert to epoch time
                        row.append((non_blobs_row[attribute_name] -
                                    datetime.date(1970, 1, 1)).days * DAY)
                    elif attribute_info.type == "time":
                        # Time attribute: return total seconds
                        row.append(
                            non_blobs_row[attribute_name].total_seconds())
                    elif re.match(r"^datetime.*$",
                                  attribute_info.type) or re.match(
                                      r"timestamp", attribute_info.type):
                        # Datetime or timestamp: convert to epoch time
                        row.append(non_blobs_row[attribute_name].timestamp())
                    elif attribute_info.type[0:7] == "decimal":
                        # Convert decimal to string
                        row.append(str(non_blobs_row[attribute_name]))
                    else:
                        # Normal attribute: convert numpy scalar types to
                        #   native Python values with .item()
                        if isinstance(non_blobs_row[attribute_name],
                                      np.generic):
                            row.append(
                                non_blobs_row[attribute_name].item())
                        else:
                            row.append(non_blobs_row[attribute_name])
                else:
                    # Blob attribute: return the value itself or the '=BLOB=' placeholder
                    (row.append(non_blobs_row[attribute_name])
                     if fetch_blobs else row.append("=BLOB="))
            # Append the completed row to the list of rows
            rows.append(row)
        return list(attributes.keys()), rows, len(query_restricted)
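
The temporal conversions above depend on a module-level DAY constant that is not part of this snippet; it is assumed here to be the number of seconds per day. A small sketch of the date branch:

import datetime

DAY = 86400   # assumed value of the module-level constant
value = datetime.date(2021, 6, 1)
epoch_seconds = (value - datetime.date(1970, 1, 1)).days * DAY
print(epoch_seconds)   # 1622505600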