Example #1
    def load_traces_and_frametimes(self, key):
        # -- find number of recording depths
        pipe = (fuse.Activity() & key).fetch('pipe')
        assert len(
            np.unique(pipe)) == 1, 'Selection is from different pipelines'
        pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
        k = dict(key)
        k.pop('field', None)
        ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
        frame_times = (stimulus.Sync()
                       & key).fetch1('frame_times').squeeze()[::ndepth]

        soma = pipe.MaskClassification.Type() & dict(type='soma')

        spikes = (dj.U('field', 'channel') * pipe.Activity.Trace() * StaticScan.Unit()
                  * pipe.ScanSet.UnitInfo() & soma & key)
        traces, ms_delay, trace_keys = spikes.fetch(
            'trace',
            'ms_delay',
            dj.key,
            order_by='animal_id, session, scan_idx, unit_id')
        delay = np.fromiter(ms_delay / 1000, dtype=float)  # np.float was removed in recent NumPy; use the builtin float
        frame_times = (delay[:, None] + frame_times[None, :])
        traces = np.vstack(
            [fill_nans(tr.astype(np.float32)).squeeze() for tr in traces])
        traces, frame_times = self.adjust_trace_len(traces, frame_times)
        return traces, frame_times, trace_keys
Example #2
def test_issue484():
    q = dj.U().aggr(S, n='max(s)')
    n = q.fetch('n')
    n = q.fetch1('n')
    q = dj.U().aggr(S, n='avg(s)')
    result = dj.U().aggr(q, m='max(n)')
    result.fetch()
Example #3
def get_session_history(session_key, remove_ignored=True):
    # Fetch data
    q_choice_outcome = (experiment.WaterPortChoice.proj(choice='water_port')
                        * experiment.BehaviorTrial.proj('outcome', 'early_lick')
                        * experiment.SessionBlock.BlockTrial) & session_key
    if remove_ignored:
        q_choice_outcome &= 'outcome != "ignore"'

    # TODO: session QC (warm-up and decreased motivation etc.)

    # -- Choice and reward --
    # 0: left, 1: right, np.nan: ignored
    _choice = q_choice_outcome.fetch('choice', order_by='trial')
    _choice[_choice == 'left'] = 0
    _choice[_choice == 'right'] = 1

    _reward = q_choice_outcome.fetch('outcome', order_by='trial') == 'hit'
    reward_history = np.zeros([2, len(_reward)])  # .shape = (2, N trials)
    for c in (0, 1):
        reward_history[c, _choice == c] = (_reward[_choice == c] == True).astype(int)
        
    if remove_ignored:  # For model fitting, turn to integer
        choice_history = np.array([_choice]).astype(int)  # .shape = (1, N trials)
    else:  # Include np.NaNs, can't be integers
        choice_history = np.array([_choice]).astype(float)  # .shape = (1, N trials)
    
    # -- ITI --
    # All previous models have an effective ITI of a constant 1.
    # For Ulises' RNN model, an important prediction is that values in the line attractor decay over (actual) time with time constant tau_CANN.
    # This takes into consideration (1) the different ITI of each trial, and (2) the long effective ITI after ignored trials.
    # Thus for the CANN model, ignored trials are also removed, but they contribute to model fitting by increasing the ITI.
    
    if (len(ephys.TrialEvent & q_choice_outcome) 
        and len(dj.U('trial') & (ephys.TrialEvent & q_choice_outcome)) == len(dj.U('trial') & (experiment.TrialEvent & q_choice_outcome))):
        # Use NI times (trial start and trial end) if (1) ephys exists (2) len(ephys) == len(behavior)
        trial_start = (ephys.TrialEvent & q_choice_outcome & 'trial_event_type = "bitcodestart"'
                       ).fetch('trial_event_time', order_by='trial').astype(float)
        trial_end = (ephys.TrialEvent & q_choice_outcome & 'trial_event_type = "trialend"'
                     ).fetch('trial_event_time', order_by='trial').astype(float)
        iti = trial_start[1:] - trial_end[:-1]  # ITI [t] --> ITI between trial t-1 and t
    else:  # If not an ephys session, we can only use PC time (should be fine because this fitting is not very time-sensitive)
        bpod_start_global = (experiment.SessionTrial & q_choice_outcome
                             ).fetch('start_time', order_by='trial').astype(float)  # This is (global) PC time
        bitcodestart_local = (experiment.TrialEvent & q_choice_outcome & 'trial_event_type = "bitcodestart"'
                              ).fetch('trial_event_time', order_by='trial').astype(float)
        trial_end_local = (experiment.TrialEvent & q_choice_outcome & 'trial_event_type = "trialend"'
                           ).fetch('trial_event_time', order_by='trial').astype(float)
        if len(bitcodestart_local) == len(trial_end_local) and len(bitcodestart_local) > 0:
            iti = (bpod_start_global[1:] + bitcodestart_local[1:]) - (bpod_start_global[:-1] + trial_end_local[:-1])
        else:
            iti = bpod_start_global[1:] - bpod_start_global[:-1]
        
    iti = np.hstack([0, iti])  # First trial iti is irrelevant
    
    # -- p_reward --
    q_p_reward = q_choice_outcome.proj() * experiment.SessionBlock.WaterPortRewardProbability & session_key
    p_reward = np.vstack([(q_p_reward & 'water_port="left"').fetch('reward_probability', order_by='trial').astype(float),
                          (q_p_reward & 'water_port="right"').fetch('reward_probability', order_by='trial').astype(float)])

    return choice_history, reward_history, iti, p_reward, q_choice_outcome
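A minimal usage sketch of the function above; the session_key values are hypothetical placeholders, and a configured connection to the experiment/ephys schemas is assumed:

# Hypothetical usage (placeholder key; requires access to the experiment/ephys schemas).
session_key = {'subject_id': 473361, 'session': 47}  # placeholder values
choice_history, reward_history, iti, p_reward, q_choice_outcome = get_session_history(session_key)
assert choice_history.shape[0] == 1   # (1, N trials)
assert reward_history.shape[0] == 2   # (2, N trials): row 0 = left, row 1 = right
assert iti.shape == (choice_history.shape[1],)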
Example #4
def get_extended_qc_fields_from_alyx(level='session'):
    if level == 'session':
        key_source = dj.U('uuid') & \
            (alyxraw.AlyxRaw.Field &
             (alyxraw.AlyxRaw & 'model="actions.session"') &
             'fname="extended_qc"' &
             'fvalue!="None"')

        fname = 'extended_qc'

    elif level == 'probe':

        key_source = dj.U('uuid') & \
            (alyxraw.AlyxRaw.Field &
             (alyxraw.AlyxRaw & 'model="experiments.probeinsertion"') &
             'fname="json"' &
             'fvalue like "%extended_qc%"')
        fname = 'json'
    else:
        raise ValueError('Incorrect level argument, has to be "session" or "probe"')

    eqc_fields = []

    for key in tqdm(key_source):
        qc_extended = str_to_dict(grf(key, fname))

        if qc_extended != 'None':
            if level == 'probe' and 'extended_qc' in qc_extended:
                qc_extended = qc_extended['extended_qc']

            eqc_fields += list(qc_extended.keys())

    return set(eqc_fields)
Example #5
def print_current_jobs():
    """
    Return a pandas.DataFrame on the status of each table currently being processed

        table | reserve_count | error_count | oldest_job | newest_job
            - table: {schema}.{table} name
            - reserve_count: number of workers currently working on this table
            - error_count: number of jobs errors for this table
            - oldest_job: timestamp of the oldest job currently being worked on
            - newest_job: timestamp of the most recent job currently being worked on

    Provides insight into the current status of the workers.

    One caveat: this function does not know how many workers are deployed or how they
    are orchestrated. Summing reserve_count gives an estimate of active workers, but it
    will not reflect idle workers, since some tables may have no "key_source" left to
    work on.
    """
    job_status = []
    for pipeline_module in (experiment, tracking, ephys, histology, psth,
                            foraging_analysis, oralfacial_analysis, report):
        reserved = dj.U('table_name').aggr(pipeline_module.schema.jobs
                                           & 'status = "reserved"',
                                           reserve_count='count(table_name)',
                                           oldest_job='MIN(timestamp)',
                                           newest_job='MAX(timestamp)')
        errored = dj.U('table_name').aggr(pipeline_module.schema.jobs
                                          & 'status = "error"',
                                          error_count='count(table_name)')
        if dj.__version__.startswith('0.13'):
            jobs_summary = reserved.join(errored, left=True)
        else:
            jobs_summary = reserved.aggr(errored, ...,
                                         error_count='error_count',
                                         keep_all_rows=True)

        for job in jobs_summary.fetch(as_dict=True):
            job_status.append({
                'table':
                f'{pipeline_module.__name__.split(".")[-1]}.{dj.utils.to_camel_case(job.pop("table_name"))}',
                **job
            })

    if not job_status:
        print('No jobs under process (0 reserved jobs)')
        return

    job_status_df = pd.DataFrame(job_status).set_index('table')
    job_status_df.fillna(0, inplace=True)
    job_status_df = job_status_df.astype({
        "reserve_count": int,
        "error_count": int
    })

    with pd.option_context('display.max_rows', None, 'display.max_columns',
                           None, 'display.width', None, 'display.max_colwidth',
                           None):  # pandas deprecated -1 for max_colwidth; None means unlimited
        print(job_status_df)

    return job_status_df
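A short usage sketch, assuming a live connection to the pipeline database; the columns follow the docstring above:

# Usage sketch: summarize reserved/errored jobs across the pipeline schemas.
jobs_df = print_current_jobs()
if jobs_df is not None:
    # a rough lower bound on the number of busy workers
    print('total reserved jobs:', jobs_df['reserve_count'].sum())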
Example #6
    def fill_up(self, tier, frames, cond, key, m):
        existing = ConditionTier().proj() & (self & dict(tier=tier)) \
                   & (stimulus.Trial() * stimulus.Condition() & dict(key, **cond))
        n = len(existing)
        if n < m:
            # all hashes that are in clips but not registered for that animal and have the right tier
            candidates = dj.U('condition_hash') & \
                         (self & (dj.U('condition_hash') & (frames - self)) & dict(tier=tier))
            keys = candidates.fetch(dj.key)
            d = m - n
            update = min(len(keys), d)

            log.info('Inserting {} more existing {} trials'.format(
                update, tier))
            for k in keys[:update]:
                k = (frames & k).fetch1(dj.key)
                k['tier'] = tier
                self.insert1(k, ignore_extra_fields=True)

        existing = ConditionTier().proj() & (self & dict(tier=tier)) \
                   & (stimulus.Trial() * stimulus.Condition() & dict(key, **cond))
        n = len(existing)
        if n < m:
            keys = (frames - self).fetch(dj.key)
            update = m - n
            log.info('Inserting {} more new {} trials'.format(update, tier))

            for k in keys[:update]:
                k['tier'] = tier
                self.insert1(k, ignore_extra_fields=True)
Example #7
    def test_join(self):
        rel = self.experiment*dj.U('experiment_date')
        assert_equal(self.experiment.primary_key, ['subject_id', 'experiment_id'])
        assert_equal(rel.primary_key, self.experiment.primary_key + ['experiment_date'])

        rel = dj.U('experiment_date')*self.experiment
        assert_equal(self.experiment.primary_key, ['subject_id', 'experiment_id'])
        assert_equal(rel.primary_key, self.experiment.primary_key + ['experiment_date'])
Example #8
 def test_aggr(self):
     rel = schema_simple.ArgmaxTest()
     amax1 = (dj.U('val') * rel) & dj.U('secondary_key').aggr(
         rel, val='min(val)')
     amax2 = (dj.U('val') * rel) * dj.U('secondary_key').aggr(
         rel, val='min(val)')
     assert_true(
         len(amax1) == len(amax2) == rel.n,
         'Aggregated argmax with join and restriction does not yield same length.'
     )
Example #9
 def test_aggr(self):
     rel = schema_simple.ArgmaxTest()
     amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(
         rel, val="min(val)")
     amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(
         rel, val="min(val)")
     assert_true(
         len(amax1) == len(amax2) == rel.n,
         "Aggregated argmax with join and restriction does not yield same length.",
     )
Example #10
 def test_aggregations(self):
     lang = schema.Language()
     # test total aggregation on expression object
     n1 = dj.U().aggr(lang, n='count(*)').fetch1('n')
     assert_equal(n1, len(lang.fetch()))
     # test total aggregation on expression class
     n2 = dj.U().aggr(schema.Language, n='count(*)').fetch1('n')
     assert_equal(n1, n2)
     rel = dj.U('language').aggr(schema.Language, number_of_speakers='count(*)')
     assert_equal(len(rel), len(set(l[1] for l in schema.Language.contents)))
     assert_equal((rel & 'language="English"').fetch1('number_of_speakers'), 3)
Example #11
 def test_restriction(self):
     language_set = {s[1] for s in self.language.contents}
     rel = dj.U('language') & self.language
     assert_list_equal(rel.heading.names, ['language'])
     assert_true(len(rel) == len(language_set))
     assert_true(set(rel.fetch('language')) == language_set)
     # Test for issue #342
     rel = self.trial*dj.U('start_time')
     assert_list_equal(rel.primary_key, self.trial.primary_key + ['start_time'])
     assert_list_equal(rel.primary_key, (rel & 'trial_id>3').primary_key)
     assert_list_equal((dj.U('start_time') & self.trial).primary_key, ['start_time'])
Example #12
    def test_join(self):
        rel = self.experiment * dj.U("experiment_date")
        assert_equal(self.experiment.primary_key,
                     ["subject_id", "experiment_id"])
        assert_equal(rel.primary_key,
                     self.experiment.primary_key + ["experiment_date"])

        rel = dj.U("experiment_date") * self.experiment
        assert_equal(self.experiment.primary_key,
                     ["subject_id", "experiment_id"])
        assert_equal(rel.primary_key,
                     self.experiment.primary_key + ["experiment_date"])
Example #13
    def plot_calibration_curve(self, calibration_date, rig):
        import matplotlib.pyplot as plt
        import seaborn as sns
        session = LaserCalibration.PowerMeasurement() & dict(
            calibration_date=calibration_date, rig=rig)
        sns.set_context('talk')
        with sns.axes_style('darkgrid'):
            fig, ax = plt.subplots()

        # sns.set_palette("husl")

        for k in (dj.U('pockels', 'bidirectional', 'gdd', 'wavelength')
                  & session).fetch('KEY'):
            pe, po, zoom = (session & k).fetch('percentage', 'power', 'zoom')
            zoom = np.unique(zoom)
            ax.plot(pe,
                    po,
                    'o-',
                    label=(u"zoom={0:.2f} ".format(zoom[0]) +
                           " ".join("{0}={1}".format(*v) for v in k.items())))
        ax.legend(loc='best')
        ax.set_xlim((0, 100))
        y_min, y_max = [np.round(y / 5) * 5 for y in ax.get_ylim()]
        ax.set_yticks(np.arange(0, y_max + 5, 5))
        ax.set_xlabel('power [in %]')
        ax.set_ylabel('power [in mW]')

        return fig, ax
Example #14
 def check_train_test_split(self, frames, cond):
     stim = getattr(stimulus, cond['stimulus_type'].split('.')[-1])
     train_test = (dj.U(*UNIQUE_FRAME[cond['stimulus_type']]).aggr(frames * stim,
                                                                   train='sum(1-test)',
                                                                   test='sum(test)') &
                   'train>0 and test>0')
     assert len(train_test) == 0, 'Train and test clips do overlap'
Example #15
 def test_aggregations(self):
     rel = dj.U('language').aggr(schema.Language(),
                                 number_of_speakers='count(*)')
     assert_equal(len(rel),
                  len(set(l[1] for l in schema.Language.contents)))
     assert_equal((rel & 'language="English"').fetch1('number_of_speakers'),
                  3)
Example #16
    def make(self, key):

        master_entry = key.copy()
        rt = key.copy()

        # get all trial sets and trials from that date
        trial_sets_proj = (behavior.TrialSet.proj(
            session_date='DATE(session_start_time)')) & key

        trial_sets_keys = (behavior.TrialSet * trial_sets_proj).fetch('KEY')

        n_trials, n_correct_trials = \
            (behavior.TrialSet & trial_sets_keys).fetch(
                'n_trials', 'n_correct_trials')

        trials = behavior.TrialSet.Trial & trial_sets_keys

        # compute the performance for easy trials
        performance_easy = utils.compute_performance_easy(trials)
        if performance_easy:
            master_entry['performance_easy'] = performance_easy

        # compute the performance for all trials
        master_entry['performance'] = np.divide(np.sum(n_correct_trials),
                                                np.sum(n_trials))

        self.insert1(master_entry)

        # compute psych results for all trials
        task_protocol = (acquisition.Session
                         & trial_sets_keys[0]).fetch1('task_protocol')

        if task_protocol and 'biased' in task_protocol:
            prob_lefts = dj.U('trial_stim_prob_left') & trials

            for ileft, prob_left in enumerate(prob_lefts):
                p_left = prob_left['trial_stim_prob_left']
                trials_sub = trials & \
                    'ABS(trial_stim_prob_left - {})<1e-6'.format(p_left)
                # compute psych results
                psych_results_tmp = utils.compute_psych_pars(trials_sub)
                psych_results = {**key, **psych_results_tmp}
                psych_results['prob_left'] = prob_left['trial_stim_prob_left']
                psych_results['prob_left_block'] = ileft
                self.PsychResults.insert1(psych_results)
                # compute reaction time
                rt['prob_left_block'] = ileft
                rt['reaction_time'] = utils.compute_reaction_time(trials_sub)
                self.ReactionTime.insert1(rt)
        else:
            psych_results_tmp = utils.compute_psych_pars(trials)
            psych_results = {**key, **psych_results_tmp}
            psych_results['prob_left'] = 0.5
            psych_results['prob_left_block'] = 1
            self.PsychResults.insert1(psych_results)

            # compute reaction time
            rt['prob_left_block'] = 1
            rt['reaction_time'] = utils.compute_reaction_time(trials)
            self.ReactionTime.insert1(rt)
Example #17
def cell_matches(animal_id, size):

    key = dict(animal_id=animal_id, **SETTINGS)
    sz = tuple(i * size_factor[size] for i in [.7, .7])
    df = pd.DataFrame((stack.StackSet.Unit() * dj.U('stack_session')
                       & key).aggr(stack.StackSet.Match() & key,
                                   matches='COUNT(*)').fetch())

    with sns.plotting_context('talk' if size == 'huge' else 'paper',
                              font_scale=1.3):
        with sns.axes_style(style="ticks", rc={"axes.facecolor":
                                               (0, 0, 0, 0)}):
            order = sorted(pd.unique(df.matches))

            g = sns.factorplot('matches',
                               kind='count',
                               hue='stack_session',
                               data=df,
                               order=order)
            sns.despine(trim=True, offset=5)

            g.ax.spines['bottom'].set_linewidth(1)
            g.ax.spines['left'].set_linewidth(1)
            g.ax.tick_params(axis='both', length=3, width=1)
            g.ax.set_xlabel('scans neuron was visible in')
            g.ax.set_ylabel('neurons')
    g.fig.set_size_inches(sz)

    return savefig(g.fig)
Example #18
    def create1_from_processing_task(self,
                                     key,
                                     is_curated=False,
                                     curation_note=""):
        """
        Given a "ProcessingTask", create a new corresponding "Curation"
        """
        if key not in Processing():
            raise ValueError(
                f"No corresponding entry in Processing available for: "
                f"{key}; run `Processing.populate(key)`")

        output_dir = (ProcessingTask & key).fetch1("processing_output_dir")
        method, imaging_dataset = get_loader_result(key, ProcessingTask)

        if method == "caiman":
            caiman_dataset = imaging_dataset
            curation_time = caiman_dataset.creation_time
        else:
            raise NotImplementedError("Unknown method: {}".format(method))

        # Synthesize curation_id
        curation_id = (dj.U().aggr(
            self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n"))
        self.insert1({
            **key,
            "curation_id": curation_id,
            "curation_time": curation_time,
            "curation_output_dir": output_dir,
            "manual_curation": is_curated,
            "curation_note": curation_note,
        })
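The curation_id synthesis above is a common dj.U() idiom: a total aggregation (no grouping attributes) over the restricted table computes max(curation_id)+1, with IFNULL supplying 1 when no rows exist yet. A minimal sketch, with MyCuration and key as hypothetical placeholders:

# Hypothetical sketch of the auto-increment idiom (MyCuration and key are placeholders).
next_id = dj.U().aggr(MyCuration & key, n='ifnull(max(curation_id)+1, 1)').fetch1('n')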
Example #19
    def make(self, key):

        task_protocol = (acquisition.Session & key).fetch1('task_protocol')

        trials = behavior.TrialSet.Trial & key

        if task_protocol and ('biased' in task_protocol):
            prob_lefts = dj.U('trial_stim_prob_left') & trials

            for prob_left in prob_lefts:
                p_left = prob_left['trial_stim_prob_left']
                trials_sub = trials & \
                    'ABS(trial_stim_prob_left - {})<1e-6'.format(p_left)

                # compute psych results
                psych_results = utils.compute_psych_pars(trials_sub)
                psych_results = {**key, **psych_results}
                psych_results['prob_left'] = prob_left['trial_stim_prob_left']
                if abs(p_left - 0.8) < 0.001:
                    psych_results['prob_left_block'] = 80
                elif abs(p_left - 0.2) < 0.001:
                    psych_results['prob_left_block'] = 20
                elif abs(p_left - 0.5) < 0.001:
                    psych_results['prob_left_block'] = 50

                self.insert1(psych_results)

        else:
            psych_results = utils.compute_psych_pars(trials)
            psych_results = {**key, **psych_results}
            psych_results['prob_left'] = 0.5
            psych_results['prob_left_block'] = 50

            self.insert1(psych_results)
Example #20
        def unique_unit_mapping(self, dynamic_scan):
            if (self * dv_nn9_scan.ScanConfig.Scan3() *
                    dv_scan3_scan_dataset.UnitConfig.Unique()):

                key = ((self * dv_nn9_scan.ScanConfig.Scan3() *
                        dv_scan3_scan_dataset.UnitConfig.Unique())
                       & dynamic_scan).fetch1()  # get unique_id
                unique_unit_key = (dv_scan3_scan.Unique() & dynamic_scan
                                   & key).fetch1("KEY")
                unique_unit_rel = (
                    dv_scan3_scan.Unique.Unit *
                    dv_scan3_scan.Unique.Neuron.proj(unique_unit_id="unit_id")
                    & unique_unit_key)
                return (dj.U("animal_id", "session", "scan_idx", "unit_id",
                             "unique_unit_id")
                        & unique_unit_rel)
            elif (self * dv_nn9_scan.ScanConfig.Scan3() *
                  dv_scan3_scan_dataset.UnitConfig.All()
                  ):  # return a mapping from all units to themselves
                units = (self * dv_nn9_scan.Scan.Unit *
                         dv_nn9_scan.ScanConfig().Scan3())
                return units.proj(unique_unit_id='unit_id * 1')
            else:
                raise NotImplementedError(
                    "`unique_unit_mapping` is not implemented for key {}!".
                    format(self.fetch1()))
Example #21
    def test_dj_u_distinct(self):
        # Test developed to see if removing DISTINCT from the select statement
        # generation breaks the dj.U universal set implementation

        # Contents to be inserted
        contents = [(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)]
        Stimulus.insert(contents)

        # Query the whole table
        test_query = Stimulus()

        # Use dj.U to create a list of unique contrast and brightness combinations
        result = dj.U('contrast', 'brightness') & test_query
        expected_result = [{
            'contrast': 2,
            'brightness': 3
        }, {
            'contrast': 3,
            'brightness': 2
        }, {
            'contrast': 5,
            'brightness': 5
        }]

        fetched_result = result.fetch(as_dict=True,
                                      order_by=('contrast', 'brightness'))
        Stimulus.delete_quick()
        assert fetched_result == expected_result
Example #22
    def make(self, key):
        task_protocol = (acquisition.Session & key).fetch1('task_protocol')

        trials = behavior.TrialSet.Trial & key & 'trial_stim_on_time is not NULL'

        if task_protocol and ('biased' in task_protocol):
            prob_lefts = dj.U('trial_stim_prob_left') & trials

            for prob_left in prob_lefts:
                rt = key.copy()
                p_left = prob_left['trial_stim_prob_left']
                trials_sub = trials & \
                    'ABS(trial_stim_prob_left - {})<1e-6'.format(p_left)

                # compute reaction_time
                rt['reaction_time_contrast'], rt['reaction_time_ci_low'], \
                    rt['reaction_time_ci_high'] = utils.compute_reaction_time(
                        trials_sub, compute_ci=True)

                if abs(p_left - 0.8) < 0.001:
                    rt['prob_left_block'] = 80
                elif abs(p_left - 0.2) < 0.001:
                    rt['prob_left_block'] = 20
                elif abs(p_left - 0.5) < 0.001:
                    rt['prob_left_block'] = 50

                self.insert1(rt)

        else:
            rt = key.copy()
            rt['prob_left_block'] = 50
            rt['reaction_time_contrast'], rt['reaction_time_ci_low'], \
                rt['reaction_time_ci_high'] = utils.compute_reaction_time(
                    trials, compute_ci=True)
            self.insert1(rt)
Example #23
    def create1_from_processing_task(self,
                                     key,
                                     is_curated=False,
                                     curation_note=''):
        """
        A convenient function to create a new corresponding "Curation" for a particular "ProcessingTask"
        """
        if key not in Processing():
            raise ValueError(
                f'No corresponding entry in Processing available for: {key};'
                f' do `Processing.populate(key)`')

        output_dir = (ProcessingTask & key).fetch1('processing_output_dir')
        method, loaded_result = get_loader_result(key, ProcessingTask)

        if method == 'caiman':
            loaded_caiman = loaded_result
            curation_time = loaded_caiman.creation_time
        elif method == 'mcgill_miniscope_analysis':
            loaded_miniscope_analysis = loaded_result
            curation_time = loaded_miniscope_analysis.creation_time
        else:
            raise NotImplementedError('Unknown method: {}'.format(method))

        # Synthesize curation_id
        curation_id = dj.U().aggr(self & key,
                                  n='ifnull(max(curation_id)+1,1)').fetch1('n')
        self.insert1({
            **key, 'curation_id': curation_id,
            'curation_time': curation_time,
            'curation_output_dir': output_dir,
            'manual_curation': is_curated,
            'curation_note': curation_note
        })
Example #24
def plot_example_cells(sort_lv = 'relative_action_value_ic',
                       sort_ep = 'iti_all',
                       best_n = 10, linear_model='Q_rel + Q_tot + rpe'):
    
#%%
    q_unit = ((ephys.Unit * ephys.ClusterMetric * ephys.UnitStat * ephys.MAPClusterMetric.DriftMetric)
              & 'presence_ratio > 0.95'
              & 'amplitude_cutoff < 0.1'
              & 'isi_violation < 0.5' 
              & 'unit_amp > 100'
              # & 'drift_metric < 0.1'
              )

    q_hist = (q_unit * histology.ElectrodeCCFPosition.ElectrodePosition) * ccf.CCFAnnotation
    q_unit_n = dj.U('annotation').aggr(q_hist, area_num_units='count(*)')
    q_hist *= q_unit_n

    lvs = (psth_foraging.LinearModel.X & {'multi_linear_model': linear_model}).fetch('var_name')
    q_all = ((psth_foraging.UnitPeriodLinearFit
              * psth_foraging.UnitPeriodLinearFit.Param
              * q_hist)
              & {'multi_linear_model': linear_model})   

    # Best n (absolute value)
    best_models = (q_all & f'var_name = "{sort_lv}"' & f'period = "{sort_ep}"').proj(
        'actual_behavior_model', abs_t='abs(t)').fetch(order_by='abs_t desc', limit=best_n, format='frame')

    for unit_key in best_models.reset_index().to_dict('records'):
        unit_psth.plot_unit_psth_choice_outcome(unit_key)
        unit_psth.plot_unit_psth_latent_variable_quantile(unit_key, 
            model_id=unit_key['actual_behavior_model'])
        unit_psth.plot_unit_period_tuning(unit_key)
Example #25
def fetch_oracle_raster(unit_key):
    """Fetches the responses of the provided unit to the oracle trials
    Args:
        unit_key      (dict):        dictionary to uniquely identify a functional unit (must contain the keys: "session", "scan_idx", "unit_id") 
        
    Returns:
        responses    (array):        array of oracle responses interpolated to scan frequency: 10 repeats x 6 oracle clips x f response frames
        oracle_score (float):        oracle score of the unit ('pearson', fetched from nda.Oracle)
    """
    fps = (nda.Scan & unit_key).fetch1('fps') # get frame rate of scan

    oracle_rel = (dj.U('condition_hash').aggr(nda.Trial & unit_key,n='count(*)',m='min(trial_idx)') & 'n=10') # get oracle clips
    oracle_hashes = oracle_rel.fetch('KEY',order_by='m ASC') # get oracle clip hashes sorted temporally

    frame_times_set = []
    # iterate over oracle repeats (10 repeats)
    for first_clip in (nda.Trial & oracle_hashes[0] & unit_key).fetch('trial_idx'): 
        trial_block_rel = (nda.Trial & unit_key & f'trial_idx >= {first_clip} and trial_idx < {first_clip+6}') # uses the trial_idx of the first clip to grab subsequent 5 clips (trial_block) 
        start_times, end_times = trial_block_rel.fetch('start_frame_time', 'end_frame_time', order_by='condition_hash DESC') # grabs start time and end time of each clip in trial_block and orders by condition_hash to maintain order across scans
        frame_times = [np.linspace(s, e , np.round(fps * (e - s)).astype(int)) for s, e in zip(start_times, end_times)] # generate time vector between start and end times according to frame rate of scan
        frame_times_set.append(frame_times)

    trace, fts, delay = ((nda.Activity & unit_key) * nda.FrameTimes * nda.ScanUnit).fetch1('trace', 'frame_times', 'ms_delay') # fetch trace delay and frame times for interpolation
    f2a = interp1d(fts + delay/1000, trace) # create trace interpolator with unit specific time delay
    oracle_traces = np.array([f2a(ft) for ft in frame_times_set]) # interpolate oracle times to match the activity trace
    oracle_traces -= np.min(oracle_traces,axis=(1,2),keepdims=True) # normalize the oracle traces
    oracle_traces /= np.max(oracle_traces,axis=(1,2),keepdims=True) # normalize the oracle traces
    oracle_score = (nda.Oracle & unit_key).fetch1('pearson') # fetch oracle score
    return oracle_traces, oracle_score
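A usage sketch for fetch_oracle_raster; the unit_key values are placeholders, and access to the nda schema (plus matplotlib) is assumed:

# Hypothetical usage (placeholder unit_key; requires the nda schema and matplotlib).
import matplotlib.pyplot as plt

unit_key = {'session': 4, 'scan_idx': 7, 'unit_id': 100}  # placeholder values
oracle_traces, oracle_score = fetch_oracle_raster(unit_key)
# oracle_traces: 10 repeats x 6 oracle clips x frames; average over repeats for a quick look
plt.imshow(oracle_traces.mean(axis=0), aspect='auto', cmap='gray')
plt.title(f'oracle score (pearson): {oracle_score:.2f}')
plt.show()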
Example #26
    def create1_from_clustering_task(self, key, curation_note=''):
        """
        A convenient function to create a new corresponding "Curation"
         for a particular "ClusteringTask"
        """
        if key not in Clustering():
            raise ValueError(f'No corresponding entry in Clustering available'
                             f' for: {key}; do `Clustering.populate(key)`')

        task_mode, output_dir = (ClusteringTask & key).fetch1(
            'task_mode', 'clustering_output_dir')
        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)

        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
            kilosort_dir)
        # Synthesize curation_id
        curation_id = dj.U().aggr(self & key,
                                  n='ifnull(max(curation_id)+1,1)').fetch1('n')
        self.insert1({
            **key, 'curation_id': curation_id,
            'curation_time': creation_time,
            'curation_output_dir': output_dir,
            'quality_control': is_qc,
            'manual_curation': is_curated,
            'curation_note': curation_note
        })
Example #27
def get_fields(chosen_cell, scans, stack_key):
    """
    function that takes in a munit_id, a set of scans, and a stack, and returns the field keys and a set of summary images for each scan_session and scan_idx in which the munit appears:
    The summary images are as follows:
    1) scan summary image which will be one of: average, correlation, l6norm, hybrid image depending on specification in argument 'functional_image'
    2) the relevant stack field after registering the imaging field inside the stack
    3) the relevant 3D segmentation image after registering the imaging field
    
    :param chosen_cell: dictionary of munit id formatted as such: {'munit_id':00000}
    :param scans: the key to restrict the stack datajoint table relation to the specified functional scans
    :param stack_key: the relevant stack to restrict with

    :return: list of field keys, one for each scan_session and scan_idx in which the munit appears
    """

    field_munit_relation_table = meso.ScanSet.Unit * (
        stack.StackSet.Match & stack_key & chosen_cell).proj(
            'munit_id', session='scan_session') & scans

    # choose an example cell and get unit_id's (unique id's per scan)
    field_munit_relation_keys = (
        dj.U('animal_id', 'stack_session', 'stack_idx', 'segmentation_method',
             'session', 'scan_idx', 'field', 'unit_id', 'munit_id')
        & field_munit_relation_table).fetch('KEY')

    return field_munit_relation_keys
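A usage sketch for get_fields; all key values below are hypothetical placeholders, and the meso and stack schemas are assumed to be available:

# Hypothetical usage (placeholder keys; requires the meso and stack schemas).
chosen_cell = {'munit_id': 10001}
scans = [{'session': 1, 'scan_idx': 1}, {'session': 2, 'scan_idx': 3}]
stack_key = {'stack_session': 1, 'stack_idx': 1}
for fk in get_fields(chosen_cell, scans, stack_key):
    print(fk['session'], fk['scan_idx'], fk['field'], fk['unit_id'])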
Example #28
def concatenated_rel(cls,
                     core_segment=None,
                     version=-1,
                     return_with_meshes=False):
    """
    Returns all core segments and their subsegments, optionally restricted by core_segment.

    :param core_segment: The core segment(s) to restrict by. If left as None, fetches all.
    :param version: The default of -1 fetches the highest version for each core segment
        and its subsegments. Explicitly passing False ignores the version.
    :param return_with_meshes: If truthy, returns the meshes from the Decimation table
        by default; passing 'Mesh' instead selects the Mesh table with the original
        meshes.
    """

    subsegment_rel = cls.Subsegment.proj()

    if core_segment is not None:
        try:
            subsegment_rel &= [
                dict(segment_id=segment_id) for segment_id in core_segment
            ]
        except TypeError:
            subsegment_rel &= dict(segment_id=core_segment)

    if version == -1:
        version_rel = dj.U('segment_id').aggr(subsegment_rel,
                                              version='max(version)')
    elif version is False:
        version_rel = subsegment_rel
    else:
        version_rel = subsegment_rel & dict(version=version)

    a_rel = dj.U('segment_id') & version_rel
    b_rel = dj.U('segment_id') & (subsegment_rel & version_rel).proj(
        _='segment_id', segment_id='subsegment_id')
    c_rel = a_rel + b_rel

    if return_with_meshes:
        if isinstance(return_with_meshes,
                      str) and return_with_meshes.lower() == 'mesh':
            c_rel = minnie.Mesh & c_rel
        else:
            c_rel = minnie.Decimation & c_rel

    return c_rel
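A usage sketch, assuming cls is bound to a hypothetical CoreSegment table class with a Subsegment part, as the docstring describes:

# Hypothetical usage (CoreSegment and the segment ids are placeholders).
latest = concatenated_rel(CoreSegment, core_segment=[100001, 100002])  # highest version per core segment
meshes = concatenated_rel(CoreSegment, core_segment=100001, return_with_meshes='Mesh')  # original meshes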
Example #29
 def load_frame_times(self, key):
     pipe = (fuse.Activity() & key).fetch('pipe')
     assert len(np.unique(pipe)) == 1, 'Selection is from different pipelines'
     pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
     k = dict(key)
     k.pop('field', None)
     ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
     return (stimulus.Sync() & key).fetch1('frame_times').squeeze()[::ndepth]
Example #30
    def key_source(self):
        ret_rel = stack.Area.proj(ret_session='scan_session',
                                  ret_scan_idx='scan_idx',
                                  ret_channel='scan_channel')
        key_source = ret_rel * StackCoordinates
        heading_str = list(self.heading)

        return dj.U(*heading_str) & key_source