示例#1
0
def main():
    """
    Build and store the distribution of each timepoint quantity.

    For every experiment, odor state, and name in QUANTITIES, the named
    timepoint field is pooled over all clean trajectories, binned with
    make_distribution, and saved as a models.TimepointDistribution record.
    A commit is issued after each distribution has been added.
    """
    for expt in session.query(models.Experiment):
        print('In experiment "{}"...'.format(expt.id))

        for odor_state in ODOR_STATES:
            print('Odor state = "{}"'.format(odor_state))

            trajs = session.query(models.Trajectory).filter_by(
                experiment=expt, odor_state=odor_state, clean=True)

            for variable_name in QUANTITIES:
                print('{}...'.format(variable_name))

                # pool this quantity over every trajectory, counting the
                # trajectories as we go (the count stays 0 if there are none)
                traj_data = []
                traj_ctr = 0
                for traj_ctr, traj in enumerate(trajs, 1):
                    traj_data.extend(
                        traj.timepoint_field(session, variable_name))

                # names ending in "_a" and heading quantities are bounded
                # below by 0; headings are also bounded above by 180
                is_heading = 'heading' in variable_name
                bounded_below = is_heading or variable_name.endswith('_a')
                lb = 0 if bounded_below else None
                ub = 180 if is_heading else None

                cts, bins = make_distribution(
                    np.array(traj_data), N_BINS, lb=lb, ub=ub)

                file_name = '{}_{}_{}.pickle'.format(
                    expt.id, odor_state, variable_name)

                record_fields = dict(
                    figure_root_path_env_var=figure_data_env_var,
                    directory_path=DIRECTORY_PATH,
                    file_name=file_name,
                    variable=variable_name,
                    experiment_id=expt.id,
                    odor_state=odor_state,
                    n_data_points=len(traj_data),
                    n_trajectories=traj_ctr,
                    bin_min=bins[0],
                    bin_max=bins[-1],
                    n_bins=N_BINS,
                )
                tp_dstr = models.TimepointDistribution(**record_fields)
                tp_dstr.data = {'cts': cts, 'bins': bins}

                session.add(tp_dstr)
                commit(session)
示例#2
0
def main():
    """
    Compute kinematic quantities for every trajectory and store them on the
    trajectory's timepoints.

    For each trajectory the speed, acceleration, heading, angular velocity,
    angular acceleration, and distance from the wall are derived from its
    positions and velocities using the kinematics module. The values are then
    written onto the timepoints one by one, and a commit is issued once per
    trajectory.
    """
    for expt in session.query(models.Experiment):

        print("In experiment '{}'".format(expt.id))

        # reciprocal of the experiment's sampling frequency
        dt = 1 / expt.sampling_frequency

        expt_trajs = session.query(models.Trajectory).filter_by(
            experiment=expt)

        for traj in expt_trajs:

            positions = traj.positions(session)
            velocities = traj.velocities(session)

            # derive the kinematic time series for the whole trajectory
            speed = kinematics.norm(velocities)

            accel = kinematics.acceleration(velocities, dt)
            accel_norm = kinematics.norm(accel)

            heading = kinematics.heading(velocities)

            ang_vel = kinematics.angular_velocity(velocities, dt)
            ang_vel_norm = kinematics.norm(ang_vel)

            ang_accel = kinematics.acceleration(ang_vel, dt)
            ang_accel_norm = kinematics.norm(ang_accel)

            wall_distance = kinematics.distance_from_wall(
                positions, WALL_BOUNDS)

            # copy the idx-th entry of every series onto the idx-th timepoint
            for idx, timepoint in enumerate(traj.timepoints(session)):

                timepoint.velocity_a = speed[idx]

                (timepoint.acceleration_x,
                 timepoint.acceleration_y,
                 timepoint.acceleration_z) = accel[idx]
                timepoint.acceleration_a = accel_norm[idx]

                (timepoint.heading_xy,
                 timepoint.heading_xz,
                 timepoint.heading_xyz) = heading[idx]

                (timepoint.angular_velocity_x,
                 timepoint.angular_velocity_y,
                 timepoint.angular_velocity_z) = ang_vel[idx]
                timepoint.angular_velocity_a = ang_vel_norm[idx]

                (timepoint.angular_acceleration_x,
                 timepoint.angular_acceleration_y,
                 timepoint.angular_acceleration_z) = ang_accel[idx]
                timepoint.angular_acceleration_a = ang_accel_norm[idx]

                timepoint.distance_from_wall = wall_distance[idx]

                session.add(timepoint)

            commit(session)
def main():
    """
    Attach a basic feature set to every crossing of each experiment's
    DETERMINATION threshold.

    The feature set holds the x/y/z position and the xy/xz/xyz heading at the
    crossing's entry, peak, and exit. Each value is the first element returned
    by crossing.timepoint_field for that field and event. A commit is issued
    after every crossing.
    """
    # Lookups are made positions first, then headings; within each group all
    # fields are read at entry, then at peak, then at exit.
    field_groups = (
        ('position_x', 'position_y', 'position_z'),
        ('heading_xy', 'heading_xz', 'heading_xyz'),
    )
    events = ('entry', 'peak', 'exit')

    for expt in session.query(models.Experiment):
        threshold_query = session.query(models.Threshold).filter_by(
            experiment=expt, determination=DETERMINATION)
        threshold = threshold_query.first()

        for cg in threshold.crossing_groups:
            print(cg.id)

            for crossing in cg.crossings:

                # keyword arguments named "<field>_<event>", e.g.
                # "position_x_entry" or "heading_xyz_exit"
                features = {}
                for fields in field_groups:
                    for event in events:
                        for field in fields:
                            values = crossing.timepoint_field(
                                session, field, 0, 0, event, event)
                            key = '{}_{}'.format(field, event)
                            features[key] = values[0]

                crossing.feature_set_basic = \
                    models.CrossingFeatureSetBasic(**features)

                session.add(crossing)
                commit(session)
示例#4
0
def crossing_number_distributions(
        CROSSING_GROUPS, AX_GRID):
    """
    Plot histograms of the number of odor crossings per trajectory.

    One histogram is drawn per crossing group. Crossings are counted over the
    clean, odor-on trajectories of the experiment the group belongs to.

    :param CROSSING_GROUPS: sequence of crossing group ids, one per subplot
    :param AX_GRID: subplot grid, indexed as (number of rows, number of columns)
    :return: the figure containing the histograms
    """

    crossing_numbers_all = {}

    for cg_id in CROSSING_GROUPS:

        # get crossing group and trajectories

        cg = session.query(models.CrossingGroup).get(cg_id)
        expt = cg.experiment

        trajs = session.query(models.Trajectory).filter_by(
            experiment=expt, odor_state='on', clean=True).all()

        # number of crossings in this group for each trajectory
        crossing_numbers_all[cg_id] = [
            len(session.query(models.Crossing).filter_by(
                crossing_group=cg, trajectory=traj).all())
            for traj in trajs
        ]

    # MAKE PLOTS

    fig_size = (6 * AX_GRID[1], 3 * AX_GRID[0])

    # squeeze=False keeps axs a 2D array even for a 1 x 1 grid, so that
    # axs.flatten() below always works
    fig, axs = plt.subplots(*AX_GRID,
        figsize=fig_size, sharex=True, sharey=True, tight_layout=True,
        squeeze=False)

    for ax, cg_id in zip(axs.flatten(), CROSSING_GROUPS):

        crossing_numbers = crossing_numbers_all[cg_id]

        # with no trajectories there is nothing to bin; leave the panel empty
        if crossing_numbers:

            # unit-width bins centred on the integers 0 .. max; the upper
            # limit is max + 1 so that the largest count gets its own bin
            # instead of falling outside the last edge
            bins = np.arange(-1, np.max(crossing_numbers) + 1) + 0.5

            # NOTE(review): newer matplotlib releases replaced `normed` with
            # `density` -- confirm against the installed version
            ax.hist(crossing_numbers, bins=bins, lw=0, normed=True)

        ax.set_xlabel('number of crossings')
        ax.set_ylabel('proportion\nof trajectories')
        ax.set_title('{}...'.format(cg_id[:15]))

    for ax in axs.flatten():

        set_fontsize(ax, 16)

    return fig
def main():
    """
    Derive kinematic quantities from each trajectory's positions and
    velocities and save them on the corresponding timepoints.

    Quantities written per timepoint: velocity_a, acceleration (x, y, z, a),
    heading (xy, xz, xyz), angular velocity (x, y, z, a), angular acceleration
    (x, y, z, a), and distance_from_wall. Changes are committed once per
    trajectory.
    """
    experiments = session.query(models.Experiment)

    for expt in experiments:

        print("In experiment '{}'".format(expt.id))

        # reciprocal of the experiment's sampling frequency
        dt = 1 / expt.sampling_frequency

        trajectories = session.query(models.Trajectory).filter_by(
            experiment=expt)

        for traj in trajectories:

            positions = traj.positions(session)
            velocities = traj.velocities(session)

            # whole-trajectory series, one entry per timepoint
            v_norms = kinematics.norm(velocities)

            a_vecs = kinematics.acceleration(velocities, dt)
            a_norms = kinematics.norm(a_vecs)

            h_vecs = kinematics.heading(velocities)

            w_vecs = kinematics.angular_velocity(velocities, dt)
            w_norms = kinematics.norm(w_vecs)

            alpha_vecs = kinematics.acceleration(w_vecs, dt)
            alpha_norms = kinematics.norm(alpha_vecs)

            wall_dists = kinematics.distance_from_wall(positions, WALL_BOUNDS)

            # store the i-th entry of each series on the i-th timepoint
            for i, tp in enumerate(traj.timepoints(session)):

                tp.velocity_a = v_norms[i]

                a_vec = a_vecs[i]
                tp.acceleration_x, tp.acceleration_y, tp.acceleration_z = a_vec
                tp.acceleration_a = a_norms[i]

                h_vec = h_vecs[i]
                tp.heading_xy, tp.heading_xz, tp.heading_xyz = h_vec

                w_vec = w_vecs[i]
                (tp.angular_velocity_x,
                 tp.angular_velocity_y,
                 tp.angular_velocity_z) = w_vec
                tp.angular_velocity_a = w_norms[i]

                alpha_vec = alpha_vecs[i]
                (tp.angular_acceleration_x,
                 tp.angular_acceleration_y,
                 tp.angular_acceleration_z) = alpha_vec
                tp.angular_acceleration_a = alpha_norms[i]

                tp.distance_from_wall = wall_dists[i]

                session.add(tp)

            commit(session)
def main():
    """
    Compute and save one timepoint distribution per experiment, odor state,
    and quantity.

    The values of the quantity are collected from every clean trajectory via
    traj.timepoint_field, histogrammed by make_distribution into N_BINS bins,
    and stored in a models.TimepointDistribution whose data holds the counts
    and the bins. Each distribution is committed as soon as it is added.
    """
    for expt in session.query(models.Experiment):
        print('In experiment "{}"...'.format(expt.id))

        for odor_state in ODOR_STATES:
            print('Odor state = "{}"'.format(odor_state))

            trajs = session.query(models.Trajectory).filter_by(
                experiment=expt,
                odor_state=odor_state,
                clean=True,
            )

            for variable_name in QUANTITIES:
                print('{}...'.format(variable_name))

                # fetch the quantity trajectory by trajectory, then flatten
                per_traj_values = [
                    traj.timepoint_field(session, variable_name)
                    for traj in trajs
                ]
                traj_ctr = len(per_traj_values)

                traj_data = []
                for values in per_traj_values:
                    traj_data.extend(values)

                # default: let make_distribution work without explicit bounds
                lb = None
                ub = None

                if 'heading' in variable_name:
                    lb, ub = 0, 180
                elif variable_name.endswith('_a'):
                    lb = 0

                samples = np.array(traj_data)
                cts, bins = make_distribution(samples, N_BINS, lb=lb, ub=ub)

                name_parts = (expt.id, odor_state, variable_name)
                file_name = '{}_{}_{}.pickle'.format(*name_parts)

                tp_dstr = models.TimepointDistribution(
                    figure_root_path_env_var=figure_data_env_var,
                    directory_path=DIRECTORY_PATH,
                    file_name=file_name,
                    variable=variable_name,
                    experiment_id=expt.id,
                    odor_state=odor_state,
                    n_data_points=len(traj_data),
                    n_trajectories=traj_ctr,
                    bin_min=bins[0],
                    bin_max=bins[-1],
                    n_bins=N_BINS,
                )
                tp_dstr.data = {'cts': cts, 'bins': bins}
                session.add(tp_dstr)

                commit(session)
def main():
    """
    Compute and store the autocorrelation of each timepoint quantity.

    For every experiment, odor state, and name in QUANTITIES, the per-trajectory
    time series are passed (as both inputs) to
    time_series.xcov_multi_with_confidence. The result is saved together with
    its p-value and confidence bounds in a models.TimepointAutocorrelation
    record, and a commit is issued after each record.
    """
    for expt in session.query(models.Experiment):
        print('In experiment "{}"...'.format(expt.id))

        for odor_state in ODOR_STATES:
            print('Odor state = "{}"'.format(odor_state))

            trajs = session.query(models.Trajectory).filter_by(
                experiment=expt, odor_state=odor_state, clean=True)

            for variable in QUANTITIES:
                print('{}...'.format(variable))

                # one time series per trajectory
                tp_data = []
                for traj in trajs:
                    tp_data.append(traj.timepoint_field(session, variable))

                series_lengths = [len(series) for series in tp_data]
                n_data_points = np.sum(series_lengths)

                # N_LAGS divided by the sampling frequency
                window_len = N_LAGS / expt.sampling_frequency

                xcov_result = time_series.xcov_multi_with_confidence(
                    tp_data, tp_data, 0, N_LAGS, normed=True)
                acor, p_value, conf_lb, conf_ub = xcov_result

                # one entry per element of acor, spaced by the sampling interval
                time_vector = np.arange(len(acor)) / expt.sampling_frequency

                file_name = '{}_{}_{}.pickle'.format(
                    expt.id, odor_state, variable)

                record_fields = dict(
                    figure_root_path_env_var=figure_data_env_var,
                    directory_path=DIRECTORY_PATH,
                    file_name=file_name,
                    variable=variable,
                    experiment_id=expt.id,
                    odor_state=odor_state,
                    n_data_points=n_data_points,
                    n_trajectories=len(tp_data),
                    window_len=window_len,
                )
                tp_acor = models.TimepointAutocorrelation(**record_fields)

                tp_acor.data = dict([
                    ('time_vector', time_vector),
                    ('autocorrelation', acor),
                    ('p_value', p_value),
                    ('confidence_lower', conf_lb),
                    ('confidence_upper', conf_ub),
                ])
                session.add(tp_acor)

                commit(session)
def main():
    """
    Create 'arbitrary' thresholds, their crossing groups, and the crossings of
    every clean trajectory.

    Two threshold values per experiment are taken from
    THRESHOLD_VALUES[expt.insect]. For each threshold and odor state a crossing
    group is created, each clean trajectory's odor series is segmented with
    time_series.segment_by_threshold, and one models.Crossing is stored per
    segment. A commit is issued after each trajectory.
    """
    for th_ctr in range(2):
        for expt in session.query(models.Experiment):
            print('Experiment "{}"'.format(expt.id))

            threshold_value = THRESHOLD_VALUES[expt.insect][th_ctr]

            # make threshold
            threshold = models.Threshold(
                experiment=expt,
                determination='arbitrary',
                value=threshold_value)
            session.add(threshold)

            # loop over odor states
            for odor_state in ODOR_STATES:
                print('Odor "{}"'.format(odor_state))

                # make crossing group
                cg = models.CrossingGroup(
                    id='{}_{}_th{}'.format(
                        expt.id, odor_state, threshold_value),
                    experiment=expt,
                    odor_state=odor_state,
                    threshold=threshold)
                session.add(cg)

                # get crossings for each trajectory
                trajs = session.query(models.Trajectory).filter_by(
                    experiment=expt, odor_state=odor_state, clean=True)

                for traj in trajs:

                    segments, peaks = time_series.segment_by_threshold(
                        traj.odors(session),
                        threshold_value,
                        traj.timepoint_ids_extended)

                    # add crossings, numbered from 1 within the trajectory
                    numbered = enumerate(zip(segments, peaks), 1)
                    for crossing_number, (segment, peak) in numbered:
                        crossing = models.Crossing(
                            trajectory=traj,
                            crossing_number=crossing_number,
                            crossing_group=cg)

                        crossing.start_timepoint_id = segment[0]
                        crossing.entry_timepoint_id = segment[1]
                        crossing.peak_timepoint_id = segment[2]
                        # exit and end are stored one less than the segment
                        # reports (presumably exclusive bounds -- TODO confirm)
                        crossing.exit_timepoint_id = segment[3] - 1
                        crossing.end_timepoint_id = segment[4] - 1
                        crossing.max_odor = peak

                        session.add(crossing)

                    commit(session)
示例#9
0
def trajectory_duration_distributions(
        EXPERIMENTS, AX_GRID):
    """
    Plot one histogram of trajectory durations per experiment.

    Durations are read from the clean, odor-on trajectories of each experiment
    id in EXPERIMENTS. AX_GRID is unpacked into plt.subplots and indexed as
    (rows, columns) to size the figure. Returns the figure.
    """

    # collect the durations for every experiment
    traj_durations = {}

    for expt_id in EXPERIMENTS:

        query = session.query(models.Trajectory).filter_by(
            experiment_id=expt_id, odor_state='on', clean=True)

        durations = []
        for traj in query.all():
            durations.append(traj.duration)

        traj_durations[expt_id] = durations

    # each panel is 6 wide and 3 tall
    n_rows, n_cols = AX_GRID[0], AX_GRID[1]
    fig_size = (6 * n_cols, 3 * n_rows)

    fig, axs = plt.subplots(
        *AX_GRID,
        figsize=fig_size,
        sharex=True,
        sharey=True,
        tight_layout=True)

    all_axes = axs.flatten()

    for ax, expt_id in zip(all_axes, EXPERIMENTS):

        ax.hist(traj_durations[expt_id], bins=50, lw=0, normed=True)

        ax.set_xlabel('duration (s)')
        ax.set_ylabel('proportion\nof trajectories')
        ax.set_title(expt_id)

    # applied to every axis, including any left without a histogram
    for ax in all_axes:

        set_fontsize(ax, 16)

    return fig
示例#10
0
    def test_segments_correctly_loaded(self):
        """
        Loading a trial with two ignored segments (10-20 and 100-200) must
        yield three data segments with the expected start and end times,
        whichever dt is used.
        """
        trial = session.query(models.Trial).get(1)

        # the try/finally guarantees the rollback even when an assertion
        # fails, so the modified trial never leaks into other tests
        try:
            trial.ignored_segments = []
            trial.ignored_segments += [models.IgnoredSegment(start_time=10, end_time=20)]
            trial.ignored_segments += [models.IgnoredSegment(start_time=100, end_time=200)]

            # make sure segments load correctly with more than one dt
            for dt in (.01, .005):
                data, _, _, _ = edr_handling.load_from_trial(trial, dt=dt)

                self.assertEqual(len(data), 3)
                # column 0 of each segment is compared against times below
                self.assertAlmostEqual(data[0][0, 0], 0, delta=.00001)
                self.assertAlmostEqual(data[0][-1, 0], 9.99, delta=.03)
                self.assertAlmostEqual(data[1][0, 0], 20, delta=.03)
                self.assertAlmostEqual(data[1][-1, 0], 99.99, delta=.03)
        finally:
            session.rollback()
示例#11
0
def main():
    """
    Compute the integrated odor of every clean trajectory and store it in a
    models.TrajectoryOdorStats attached to the trajectory.

    MOSQUITO_BASELINE_ODOR is subtracted from the odor for experiments whose
    id contains 'mosquito'; no baseline is subtracted otherwise. Changes are
    committed once per experiment.
    """
    for expt in session.query(models.Experiment):
        baseline = MOSQUITO_BASELINE_ODOR if 'mosquito' in expt.id else 0

        clean_trajs = session.query(models.Trajectory).filter_by(
            experiment=expt, clean=True)

        for traj in clean_trajs:
            odor_above_baseline = traj.odors(session) - baseline

            # NOTE(review): the divisor 100 is hard-coded -- presumably the
            # sampling rate in Hz; confirm
            integrated_odor = odor_above_baseline.sum() / 100

            stats = models.TrajectoryOdorStats(integrated_odor=integrated_odor)
            traj.odor_stats = stats

            session.add(traj)

        commit(session)
def main():
    """
    Create crossing groups and crossings for each experiment's DETERMINATION
    threshold.

    For every odor state a models.CrossingGroup is created; each clean
    trajectory's odor series is segmented at the threshold value with
    time_series.segment_by_threshold and one models.Crossing is stored per
    segment. A commit is issued after each trajectory.
    """
    for expt in session.query(models.Experiment):
        print('Experiment "{}"'.format(expt.id))

        threshold_query = session.query(models.Threshold).filter_by(
            experiment=expt, determination=DETERMINATION)
        threshold = threshold_query.first()

        # loop over odor states

        for odor_state in ODOR_STATES:
            print('Odor "{}"'.format(odor_state))

            # make crossing group
            id_parts = (expt.id, odor_state, threshold.value, DETERMINATION)
            cg = models.CrossingGroup(
                id='{}_{}_th{}_{}'.format(*id_parts),
                experiment=expt,
                odor_state=odor_state,
                threshold=threshold)
            session.add(cg)

            # get crossings for each trajectory
            trajs = session.query(models.Trajectory).filter_by(
                experiment=expt, odor_state=odor_state, clean=True)

            for traj in trajs:

                segments, peaks = time_series.segment_by_threshold(
                    traj.odors(session),
                    threshold.value,
                    traj.timepoint_ids_extended)

                # add crossings, numbered from 1 within the trajectory
                numbered = enumerate(zip(segments, peaks), 1)
                for crossing_number, (segment, peak) in numbered:
                    crossing = models.Crossing(
                        trajectory=traj,
                        crossing_number=crossing_number,
                        crossing_group=cg)

                    crossing.start_timepoint_id = segment[0]
                    crossing.entry_timepoint_id = segment[1]
                    crossing.peak_timepoint_id = segment[2]
                    # exit and end are stored one less than the segment
                    # reports (presumably exclusive bounds -- TODO confirm)
                    crossing.exit_timepoint_id = segment[3] - 1
                    crossing.end_timepoint_id = segment[4] - 1
                    crossing.max_odor = peak

                    session.add(crossing)

                commit(session)
示例#13
0
def main():
    """
    Store the integrated odor of each clean trajectory.

    The odor series, minus a baseline, is summed and divided by 100; the result
    is saved in a models.TrajectoryOdorStats set on the trajectory. The
    baseline is MOSQUITO_BASELINE_ODOR when 'mosquito' appears in the
    experiment id and 0 otherwise. One commit is issued per experiment.
    """
    experiments = session.query(models.Experiment)

    for expt in experiments:
        is_mosquito = 'mosquito' in expt.id
        baseline = MOSQUITO_BASELINE_ODOR if is_mosquito else 0

        trajs = session.query(models.Trajectory).filter_by(
            experiment=expt, clean=True)

        for traj in trajs:
            odor = traj.odors(session)
            corrected = odor - baseline

            # NOTE(review): hard-coded divisor; looks like a sampling rate of
            # 100 Hz -- confirm
            total = corrected.sum() / 100

            traj.odor_stats = models.TrajectoryOdorStats(integrated_odor=total)

            session.add(traj)

        commit(session)
def main():
    """
    Save the autocorrelation of every quantity for each experiment and odor
    state.

    The time series of a quantity, one per clean trajectory, are fed to
    time_series.xcov_multi_with_confidence against themselves. The returned
    autocorrelation, p-value, and confidence bounds are stored, along with a
    matching time vector, in a models.TimepointAutocorrelation. Each record is
    committed right after it is added.
    """
    for expt in session.query(models.Experiment):
        print('In experiment "{}"...'.format(expt.id))

        for odor_state in ODOR_STATES:
            print('Odor state = "{}"'.format(odor_state))

            trajs = session.query(models.Trajectory).filter_by(
                experiment=expt,
                odor_state=odor_state,
                clean=True,
            )

            for variable in QUANTITIES:
                print('{}...'.format(variable))

                # one time series per trajectory
                tp_data = [
                    traj.timepoint_field(session, variable)
                    for traj in trajs
                ]
                lengths = [len(series) for series in tp_data]
                n_data_points = np.sum(lengths)
                n_trajectories = len(tp_data)

                fs = expt.sampling_frequency
                window_len = N_LAGS / fs

                acor, p_value, conf_lb, conf_ub = (
                    time_series.xcov_multi_with_confidence(
                        tp_data, tp_data, 0, N_LAGS, normed=True))

                # one entry per element of acor, spaced by the sampling interval
                time_vector = np.arange(len(acor)) / expt.sampling_frequency

                file_name = '{}_{}_{}.pickle'.format(
                    expt.id, odor_state, variable)

                tp_acor = models.TimepointAutocorrelation(
                    figure_root_path_env_var=figure_data_env_var,
                    directory_path=DIRECTORY_PATH,
                    file_name=file_name,
                    variable=variable,
                    experiment_id=expt.id,
                    odor_state=odor_state,
                    n_data_points=n_data_points,
                    n_trajectories=n_trajectories,
                    window_len=window_len,
                )

                payload = {}
                payload['time_vector'] = time_vector
                payload['autocorrelation'] = acor
                payload['p_value'] = p_value
                payload['confidence_lower'] = conf_lb
                payload['confidence_upper'] = conf_ub
                tp_acor.data = payload

                session.add(tp_acor)

                commit(session)
示例#15
0
def main():
    """
    Split every trajectory into its clean portions and store them.

    For each insect the cleaning parameters are loaded into a dict and passed,
    with each trajectory, to clean_traj. A portion that spans the whole
    trajectory marks the trajectory itself as clean; any other portion becomes
    a new, clean, non-raw models.Trajectory whose id is the original id plus
    '_c<index>'. Basic info is (re)computed for every portion, and a commit is
    issued after each original trajectory.
    """
    for insect in INSECTS:
        cleaning_params_list = session.query(
            models.TrajectoryCleaningParameter.param,
            models.TrajectoryCleaningParameter.value).\
            filter_by(insect=insect).all()
        # (param, value) rows -> {param: value}
        cleaning_params = dict(cleaning_params_list)

        for expt in session.query(models.Experiment).filter_by(insect=insect):
            for traj in expt.trajectories:

                clean_portions = clean_traj(traj, cleaning_params)

                for ctr, clean_portion in enumerate(clean_portions):

                    # each portion is a (start timepoint id, end timepoint id)
                    # pair
                    stp_id, etp_id = clean_portion

                    spans_whole_traj = (
                        stp_id == traj.start_timepoint_id
                        and etp_id == traj.end_timepoint_id)

                    if spans_whole_traj:
                        traj.clean = True
                        portion_traj = traj

                    else:
                        # make new trajectory; the name avoids shadowing the
                        # builtin id()
                        portion_id = traj.id + '_c{}'.format(ctr)
                        portion_traj = models.Trajectory(
                            id=portion_id,
                            start_timepoint_id=stp_id,
                            end_timepoint_id=etp_id,
                            experiment=expt,
                            raw=False,
                            clean=True,
                            odor_state=traj.odor_state)

                    # added before computing basic info, as in the original
                    # ordering; a second add afterwards would be redundant
                    session.add(portion_traj)
                    portion_traj.basic_info = make_trajectory_basic_info(
                        portion_traj)

                commit(session)
示例#16
0
def main():
    """
    Interactively mark ignored segments on trials and optionally save them.

    Each of the first N_TRIALS matching trials is plotted with a
    SegmentSelector, pre-populated with the trial's existing ignored segments.
    When the plot window is closed, the trial's ignored segments are replaced
    by the ones held by the selector. After all trials have been shown, the
    user is asked whether to commit or roll back the changes.
    """
    instructions = ('Right click on trial to mark starts and ends of segments.\n'
                    'Close window to move on to next trial.')
    print(instructions)

    # get trials
    trials = session.query(models.Trial). \
        filter(models.Trial.experiment_id == EXPERIMENT_ID). \
        filter(models.Trial.recording_start >= EARLIEST_DATETIME)

    for trial in trials[:N_TRIALS]:
        fig = plt.figure(facecolor='white', figsize=FIG_SIZE)
        fig, axs, edr_data = plot_trial_basic(trial, fig, cols=COLS, dt=.02)

        # define previously specified ignored segments if there are any
        if trial.ignored_segments:
            ignored_segments_simple = [(i_s.start_time, i_s.end_time) for i_s in trial.ignored_segments]
        else:
            ignored_segments_simple = None

        # first and last entries of column 0 of the plotted data
        t_min = edr_data[0, 0]
        t_max = edr_data[-1, 0]

        # create segment selector
        segment_selector = SegmentSelector(fig, axs, t_min, t_max, segments=ignored_segments_simple)

        # blocks until the window is closed
        plt.show(block=True)

        # remove all ignored segments that were previously bound to this trial
        # (a plain loop, since this is done purely for its side effect)
        for i_s in list(trial.ignored_segments):
            session.delete(i_s)
        trial.ignored_segments = []

        # add all ignored segments that we've created in the segment selector
        for segment in segment_selector.segments_simple:
            ignored_segment = models.IgnoredSegment()
            # each segment is stored as its (start, end) pair, unchanged
            ignored_segment.start_time = segment[0]
            ignored_segment.end_time = segment[1]

            trial.ignored_segments += [ignored_segment]

    # raw_input only exists on Python 2; on Python 3 the equivalent is input
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    save = read_line('Save [y or n]?')

    if save.lower() == 'y':
        # add all updated trials to database
        for trial in trials:
            session.add(trial)
        session.commit()
    else:
        print('Data not saved.')
        session.rollback()
示例#17
0
def main(traj_limit=None):
    """
    Build geom config groups from real wind tunnel trajectories.

    For every experiment id and odor state a models.GeomConfigGroup is
    created. Each clean wind tunnel trajectory is discretized with
    ENV.discretize_position_sequence and turned into a models.GeomConfig in
    that group, with an extension recording the real trajectory id and the
    average time step. Each group is committed after its trajectories have
    been processed.

    :param traj_limit: if truthy, stop after this many trajectories per
        experiment/odor state; None (default) processes all of them
    """

    # add script execution to infotaxis database
    add_script_execution(script_id=SCRIPT_ID,
                         notes=SCRIPT_NOTES,
                         session=session)
    session.commit()

    for experiment_id in EXPERIMENT_IDS:
        # parenthesized so that it is valid on both Python 2 and Python 3
        print(experiment_id)
        for odor_state in ODOR_STATES:

            # make geom_config_group
            geom_config_group_id = '{}_{}_odor_{}'.format(
                GEOM_CONFIG_GROUP_ID, experiment_id, odor_state)
            geom_config_group_desc = GEOM_CONFIG_GROUP_DESC.format(
                experiment_id, odor_state)
            geom_config_group = models.GeomConfigGroup(
                id=geom_config_group_id, description=geom_config_group_desc)

            # get all wind tunnel trajectories of interest
            trajs = wt_session.query(wt_models.Trajectory).\
                filter_by(
                    experiment_id=experiment_id,
                    odor_state=odor_state,
                    clean=True
                )

            for tctr, traj in enumerate(trajs):

                positions = traj.positions(wt_session)

                discrete_trajectory = ENV.discretize_position_sequence(
                    positions)
                discrete_duration = len(discrete_trajectory)
                # NOTE(review): .01 is presumably the original sampling
                # interval in seconds -- confirm
                avg_dt = .01 * len(positions) / discrete_duration
                geom_config = models.GeomConfig(duration=discrete_duration)
                geom_config.start_idx = discrete_trajectory[0]
                geom_config.geom_config_group = geom_config_group

                # add extension containing extra data about this geom_config
                ext = models.GeomConfigExtensionRealTrajectory(
                    real_trajectory_id=traj.id, avg_dt=avg_dt)
                geom_config.extension_real_trajectory = ext

                # stop once traj_limit trajectories have been processed
                if traj_limit and (tctr == traj_limit - 1):

                    break

            session.add(geom_config_group)
            session.commit()
示例#18
0
    def test_segments_correctly_loaded_with_beginning_ignored_segments(self):
        """
        Check how many data segments are loaded when the first ignored segment
        starts at, or very close to, the beginning of the trial.

        With dt=.01, a first ignored segment starting at 0 or 0.05 must give
        two data segments, while one starting at 0.11 must give three.
        """
        # (start time of the first ignored segment, expected number of
        # loaded data segments)
        scenarios = (
            (0, 2),
            (0.05, 2),
            (0.11, 3),
        )

        for first_start_time, expected_n_segments in scenarios:
            trial = session.query(models.Trial).get(1)

            # roll back after every scenario, even when an assertion fails,
            # so the modified trial never leaks into other tests
            try:
                trial.ignored_segments = []
                trial.ignored_segments += [
                    models.IgnoredSegment(start_time=first_start_time, end_time=20)]
                trial.ignored_segments += [
                    models.IgnoredSegment(start_time=100, end_time=200)]

                data, _, _, _ = edr_handling.load_from_trial(trial, dt=.01)

                self.assertEqual(len(data), expected_n_segments)
            finally:
                session.rollback()
def main(traj_limit=None):
    """
    Create one geom_config_group per experiment/odor state from clean wind tunnel trajectories.

    Every trajectory is discretized with ENV and attached to the group as a GeomConfig
    that carries a real-trajectory extension (trajectory id and average time step).
    The group is added to the session and committed once per experiment/odor state.
    :param traj_limit: if truthy, stop after this many trajectories per group
    """

    # add script execution to infotaxis database
    add_script_execution(script_id=SCRIPT_ID, notes=SCRIPT_NOTES, session=session)
    session.commit()

    for experiment_id in EXPERIMENT_IDS:
        # call form of print is valid under both Python 2 and Python 3
        print(experiment_id)
        for odor_state in ODOR_STATES:

            # make geom_config_group
            geom_config_group_id = '{}_{}_odor_{}'.format(GEOM_CONFIG_GROUP_ID, experiment_id, odor_state)
            geom_config_group_desc = GEOM_CONFIG_GROUP_DESC.format(experiment_id, odor_state)
            geom_config_group = models.GeomConfigGroup(id=geom_config_group_id,
                                                       description=geom_config_group_desc)

            # get all wind tunnel trajectories of interest
            trajs = wt_session.query(wt_models.Trajectory).\
                filter_by(
                    experiment_id=experiment_id,
                    odor_state=odor_state,
                    clean=True
                )

            for tctr, traj in enumerate(trajs):

                positions = traj.positions(wt_session)

                discrete_trajectory = ENV.discretize_position_sequence(positions)
                discrete_duration = len(discrete_trajectory)
                # NOTE(review): .01 is presumably the sampling period of the raw
                # positions in seconds -- confirm against the wind tunnel data
                avg_dt = .01 * len(positions) / discrete_duration
                geom_config = models.GeomConfig(duration=discrete_duration)
                geom_config.start_idx = discrete_trajectory[0]
                geom_config.geom_config_group = geom_config_group

                # add extension containing extra data about this geom_config
                ext = models.GeomConfigExtensionRealTrajectory(real_trajectory_id=traj.id,
                                                               avg_dt=avg_dt)
                geom_config.extension_real_trajectory = ext

                # tctr is zero-based, so this stops after traj_limit trajectories
                if traj_limit and (tctr == traj_limit - 1):

                    break

            session.add(geom_config_group)
            session.commit()
def main():
    """
    Register every .edr file of the experiment directory as a Trial.

    Files whose name already exists in the Trial table are skipped. For each new
    file the recording start and duration are read from the EDR header, the insect
    id is built from the recording date and the insect number found in the file
    name, and a Trial linked to that insect and the experiment is added to the
    session. A failure on one file is printed and does not stop the loop.
    """
    # get or create experiment
    experiment = get_or_create(session, models.Experiment,
                               id=EXPERIMENT_ID,
                               directory_path=EXPERIMENT_DIRECTORY_PATH)

    full_directory_path = os.path.join(ARENA_DATA_DIRECTORY, EXPERIMENT_DIRECTORY_PATH)

    edr_files = [file_name for file_name in os.listdir(full_directory_path)
                 if file_name.lower().endswith('.edr')]

    # get all trials and add them to database
    for file_name in edr_files:
        print('Attempting to load file "{}"'.format(file_name))

        # skip if file already added
        if session.query(models.Trial).filter_by(file_name=file_name).first():
            print('Skipping file "{}" because it is already in the database.'.format(file_name))
            continue

        try:
            file_path = os.path.join(full_directory_path, file_name)
            _, recording_start, _, header = edr_handling.load_edr(file_path)
            recording_duration = header['recording_duration']

            # get insect number from file name using regex
            insect_number = re.findall(INSECT_NUMBER_EXPRESSION, file_name)[0]

            insect_id = '{}_{}'.format(recording_start.strftime('%Y%m%d'), insect_number)

            # get/create insect
            insect = get_or_create(session, models.Insect, id=insect_id)

            trial = models.Trial(file_name=file_name,
                                 recording_start=recording_start,
                                 recording_duration=recording_duration)

            trial.insect = insect
            trial.experiment = experiment

            session.add(trial)

        # deliberately broad: one unreadable file must not abort the whole import
        # ("as" form is valid on Python 2.6+ and Python 3)
        except Exception as e:
            print('Error: "{}"'.format(e))
示例#21
0
def show_distributions(EXPERIMENTS, VARIABLES, ODOR_STATES, AX_GRID):
    """
    Plot stored timepoint distributions, one axis per (experiment, variable) pair.

    :param EXPERIMENTS: experiment ids
    :param VARIABLES: variable names
    :param ODOR_STATES: odor states, one curve per state (paired with three fixed
        colors, so only the first three states are drawn)
    :param AX_GRID: (n_rows, n_cols) of the subplot grid
    :return: figure
    """

    AX_SIZE = (6, 4)
    LW = 2
    COLORS = ('b', 'g', 'r')

    fig_size = (AX_SIZE[0] * AX_GRID[1], AX_SIZE[1] * AX_GRID[0])

    # squeeze=False always yields a 2D array of axes, so flatten() also works
    # for a 1 x 1 grid (where subplots would otherwise return a bare Axes)
    fig, axs = plt.subplots(*AX_GRID,
                            figsize=fig_size,
                            tight_layout=True,
                            squeeze=False)
    flat_axs = axs.flatten()

    for ax, (expt_id, variable) in zip(flat_axs,
                                       cproduct(EXPERIMENTS, VARIABLES)):

        handles = []

        for odor_state, color in zip(ODOR_STATES, COLORS):

            tp_dstr = session.query(models.TimepointDistribution).filter_by(
                variable=variable,
                experiment_id=expt_id,
                odor_state=odor_state).first()

            # plot returns a list of lines; keep the single line for the legend
            handles.append(
                ax.plot(tp_dstr.bincs,
                        tp_dstr.cts,
                        lw=LW,
                        color=color,
                        label=odor_state)[0])

        ax.set_xlabel(variable)
        ax.set_ylabel('counts')

        ax.legend(handles=handles)

        ax.set_title('{}\n{}'.format(expt_id, variable))

    for ax in flat_axs:

        set_font_size(ax, 16)

    return fig
def get_trajs_with_integrated_odor_above_threshold(experiment_id, odor_state, integrated_odor_threshold, max_trajs=np.inf):
    """
    Return clean trajectories from a given experiment/odor state whose integrated odor
    is strictly above a threshold.
    :param experiment_id: experiment id
    :param odor_state: odor state
    :param integrated_odor_threshold: threshold
    :param max_trajs: maximum number of trajectories to return (default: no limit)
    :return: list of trajectories
    """
    trajs = []

    # a non-positive limit can never be satisfied by appending, so return early
    # instead of letting one trajectory slip through before the limit check
    if max_trajs <= 0:
        return trajs

    trajs_all = session.query(models.Trajectory).filter(
        models.Trajectory.experiment_id == experiment_id,
        models.Trajectory.odor_state == odor_state,
        models.Trajectory.clean,
    )

    for traj in trajs_all:
        if traj.odor_stats.integrated_odor > integrated_odor_threshold:
            trajs.append(traj)
            # the list only grows here, so this is the only place the limit can be hit
            if len(trajs) >= max_trajs:
                break

    return trajs
def show_distributions(EXPERIMENTS, VARIABLES, ODOR_STATES, AX_GRID):
    """
    Draw the stored timepoint distribution of every (experiment, variable) pair.

    Each pair gets its own axis; within an axis there is one curve per odor state.
    :param EXPERIMENTS: experiment ids
    :param VARIABLES: variable names
    :param ODOR_STATES: odor states (zipped with three fixed colors, so at most
        three states are drawn)
    :param AX_GRID: (n_rows, n_cols) of the subplot grid
    :return: figure
    """

    AX_SIZE = (6, 4)
    LW = 2
    COLORS = ('b', 'g', 'r')

    fig_size = (AX_SIZE[0] * AX_GRID[1], AX_SIZE[1] * AX_GRID[0])

    # without squeeze=False a 1 x 1 grid returns a single Axes object and the
    # flatten() calls below fail; with it the result is always a 2D array
    fig, axs = plt.subplots(*AX_GRID,
                            figsize=fig_size,
                            tight_layout=True,
                            squeeze=False)

    for ax, (expt_id, variable) in zip(axs.flatten(), cproduct(EXPERIMENTS, VARIABLES)):

        handles = []

        for odor_state, color in zip(ODOR_STATES, COLORS):

            tp_dstr = session.query(models.TimepointDistribution).filter_by(
                variable=variable, experiment_id=expt_id,
                odor_state=odor_state).first()

            # ax.plot returns a list; its first element is the line used in the legend
            handles.append(ax.plot(
                tp_dstr.bincs, tp_dstr.cts, lw=LW, color=color, label=odor_state)[0])

        ax.set_xlabel(variable)
        ax.set_ylabel('counts')

        ax.legend(handles=handles)

        ax.set_title('{}\n{}'.format(expt_id, variable))

    for ax in axs.flatten():

        set_font_size(ax, 16)

    return fig
import matplotlib.cm as cm
from db_api import models
from db_api.connect import session

# Script configuration for interactively browsing crossing groups.
ODOR_STATES = ('on', 'none', 'afterodor')
# per-insect pair of values -- presumably (lower, upper) odor thresholds; TODO confirm
THRESHOLDS = {'fruit_fly': (0.01, 0.1), 'mosquito': (401, 410)}
# NOTE(review): looks like a threshold step size -- confirm where it is used
DTH = 0.0001
TIMEPOINTS_BEFORE_ENTRY = 50
TIMEPOINTS_AFTER_EXIT = 50

# plot appearance
FACE_COLOR = 'white'
FIG_SIZE = (8, 10)
LW = 2
# interactive mode, so the figure can be updated while the loop below runs
plt.ion()

# browsing state: all experiments plus counters for experiment, odor state and threshold
expts = session.query(models.Experiment).all()
keep_going = True
e_ctr = 0
o_ctr = 0
th_ctr = 0

# three stacked axes; the last one gets a second y axis stored on the axis object
fig, axs = plt.subplots(3,
                        1,
                        facecolor=FACE_COLOR,
                        figsize=FIG_SIZE,
                        tight_layout=True)
axs[2].twin = axs[2].twinx()

while keep_going:

    # get new crossing group
示例#25
0
def main():
    """
    Store a basic feature set for every crossing of each experiment's threshold.

    For each crossing, the position (x, y, z) and heading (xy, xz, xyz) are read at
    the entry, peak and exit timepoints and saved as a CrossingFeatureSetBasic whose
    fields are named '<variable>_<event>'. The session is committed per crossing.
    """
    # variables are read group by group, and within a group event by event
    variable_groups = (
        ('position_x', 'position_y', 'position_z'),
        ('heading_xy', 'heading_xz', 'heading_xyz'),
    )
    events = ('entry', 'peak', 'exit')

    for expt in session.query(models.Experiment):
        threshold = session.query(models.Threshold).filter_by(
            experiment=expt, determination=DETERMINATION).first()

        for cg in threshold.crossing_groups:
            print(cg.id)

            for crossing in cg.crossings:
                features = {}

                for variables in variable_groups:
                    for event in events:
                        for variable in variables:
                            # zero offsets relative to the same event on both ends;
                            # the first returned element is the value that is stored
                            values = crossing.timepoint_field(
                                session, variable, 0, 0, event, event)
                            features['{}_{}'.format(variable, event)] = values[0]

                crossing.feature_set_basic = models.CrossingFeatureSetBasic(**features)

                session.add(crossing)
                commit(session)
示例#26
0
# Script configuration: which threshold to analyze and how to split/display crossings.
THRESHOLD_ID = 4  # look up in threshold table
# crossings are split by whether their max odor is below or at/above this value
DISCRIMINATION_THRESHOLD = 450
DISPLAY_START = -50
DISPLAY_END = 150
INTEGRAL_START = 0  # timepoints (fs 100 Hz) relative to peak
INTEGRAL_END = 100
VARIABLE = 'heading_xyz'

# plot appearance
FACE_COLOR = 'white'
FIG_SIZE = (10, 10)
LW = 2
COLORS = ['k', 'r']  # [below, above] discrimination threshold

# get threshold and crossing group for odor on trajectories
threshold = session.query(models.Threshold).get(THRESHOLD_ID)
cg = session.query(models.CrossingGroup).\
    filter_by(threshold=threshold, odor_state='on').first()

# every crossing of the group, regardless of its max odor
all_crossings = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).all()

# get all crossings where max odor is below/above discrimination threshold
# (a crossing exactly at the threshold counts as "above")
crossings_below = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).\
    filter(models.Crossing.max_odor < DISCRIMINATION_THRESHOLD).all()
crossings_above = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).\
    filter(models.Crossing.max_odor >= DISCRIMINATION_THRESHOLD).all()

# make array for storing crossing-specific time-series
示例#27
0
                              figsize=FIG_SIZE_POS,
                              facecolor=FACE_COLOR,
                              sharex=True,
                              sharey=True,
                              tight_layout=True)
# second figure: four stacked axes sharing the x axis
fig_pos, axs_pos = plt.subplots(4,
                                1,
                                figsize=FIG_SIZE_POS,
                                facecolor=FACE_COLOR,
                                sharex=True,
                                tight_layout=True)

# legend entries, one per experiment (wind speed)
wind_speed_handles = []
wind_speed_labels = []

for expt_ctr, expt in enumerate(session.query(models.Experiment)):

    print(expt.id)

    # label prefix depends on the insect named in the experiment id; an id that
    # contains neither substring adds no label
    if 'fruitfly' in expt.id:

        wind_speed_labels.append('f {} m/s'.format(expt.wind_speed))

    elif 'mosquito' in expt.id:

        wind_speed_labels.append('m {} m/s'.format(expt.wind_speed))

    # per-experiment legend entries, reset for each experiment
    subset_handles = []
    subset_labels = []
    subset_handles_pos = []
    subset_labels_pos = []
示例#28
0
X_LIM = 0, 30
Y_LIM = 40, 130

# 2 x 3 grid: rows index the segment-group template, columns the odor state
fig, axs = plt.subplots(2,
                        3,
                        facecolor='white',
                        figsize=FIG_SIZE,
                        tight_layout=True)

for s_ctr, sg_id_template in enumerate(
    (SEGMENT_GROUP_ID_EMPIRICAL, SEGMENT_GROUP_ID_INFOTAXIS)):
    for e_ctr, expt in enumerate(EXPERIMENTS):
        for o_ctr, odor_state in enumerate(ODOR_STATES):

            sg_id = sg_id_template.format(expt, odor_state)
            seg_group = session.query(models.SegmentGroup).get(sg_id)

            heading_ensemble = None
            for ens in seg_group.analysis_triggered_ensembles:
                # only get the heading ensemble if it has no conditions
                # NOTE(review): only the variable and trigger_start are checked
                # here; no condition fields are inspected -- confirm intent
                if ens.variable == 'heading' and ens.trigger_start == 'exit':
                    heading_ensemble = ens
                    break

            # no exit-triggered heading ensemble in this group: skip it rather
            # than fail with an AttributeError on None below
            if heading_ensemble is None:
                continue

            heading_ensemble.fetch_data(session)
            if heading_ensemble._data is None:
                continue

            time_vector = np.arange(len(heading_ensemble.mean))

            ax = axs[s_ctr, o_ctr]
def optimize_model_params(
        SEED,
        DURATION, DT, BOUNDS,
        EXPERIMENT, ODOR_STATE,
        MAX_TRAJS_EMPIRICAL,
        N_TIME_POINTS_EMPIRICAL,
        SAVE_FILE_PREFIX,
        INITIAL_PARAMS, MAX_ITERS):
    """
    Find optimal model parameters by fitting the speed, angular velocity and y-position
    distributions of empirical data.

    The objective is the sum of the three two-sample Kolmogorov-Smirnov statistics
    between a simulated trajectory and the sampled empirical time points, minimized
    with scipy's downhill simplex (fmin). Sampled empirical time points are cached in
    '<SAVE_FILE_PREFIX>_<EXPERIMENT>_odor_<ODOR_STATE>.npy' and reused when present.
    :param SEED: random seed, reset before sampling and before every simulation
    :param DURATION: duration passed to the agent's track method
    :param DT: time step passed to the agent's track method and to angular_velocity
    :param BOUNDS: one (lower, upper) pair per axis; used for start positions,
        agent bounds and plot limits
    :param EXPERIMENT: experiment id of the empirical trajectories
    :param ODOR_STATE: odor state of the empirical trajectories
    :param MAX_TRAJS_EMPIRICAL: maximum number of empirical trajectories to query
    :param N_TIME_POINTS_EMPIRICAL: number of time points sampled (without replacement)
        from each empirical variable
    :param SAVE_FILE_PREFIX: prefix of the cache file name
    :param INITIAL_PARAMS: initial (tau, noise, bias)
    :param MAX_ITERS: maximum number of optimizer iterations
    :return: figure comparing model and data distributions, plus an example trajectory
    """

    # check to see if empirical time points have already been saved

    file_name = '{}_{}_odor_{}.npy'.format(SAVE_FILE_PREFIX, EXPERIMENT, ODOR_STATE)

    if os.path.isfile(file_name):

        # the cache is a one-element object array wrapping a dict, which numpy
        # stores via pickle; current numpy refuses to load that unless
        # allow_pickle=True. NOTE: pickle is only safe for files this script wrote.
        empirical = np.load(file_name, allow_pickle=True)[0]

    else:

        print('extracting time points from data')

        # get all trajectories

        trajs = session.query(models.Trajectory).filter_by(
            experiment_id=EXPERIMENT, odor_state=ODOR_STATE, clean=True).\
            limit(MAX_TRAJS_EMPIRICAL).all()

        # get all speeds and angular velocities

        cc = np.concatenate
        speeds_empirical = cc([traj.velocities_a(session) for traj in trajs])
        ws_empirical = cc([traj.angular_velocities_a(session) for traj in trajs])
        ys_empirical = cc([traj.timepoint_field(session, 'position_y') for traj in trajs])

        # sample a set of speeds and ws

        np.random.seed(SEED)

        speeds_empirical = np.random.choice(speeds_empirical, N_TIME_POINTS_EMPIRICAL, replace=False)
        ws_empirical = np.random.choice(ws_empirical, N_TIME_POINTS_EMPIRICAL, replace=False)
        ys_empirical = np.random.choice(ys_empirical, N_TIME_POINTS_EMPIRICAL, replace=False)

        empirical = {'speeds': speeds_empirical, 'ws': ws_empirical, 'ys': ys_empirical}

        # save them for easy access next time

        np.save(file_name, np.array([empirical]))

    print('performing optimization')

    # make a plume

    pl = GaussianLaminarPlume(0, np.zeros((2,)), np.ones((2,)))

    def simulate(p):
        """Run one seeded simulation for params p; return (traj, speeds, ws, ys)."""

        # reseeding makes every evaluation use the same start position and noise
        np.random.seed(SEED)

        start_pos = np.array([
            np.random.uniform(*BOUNDS[0]),
            np.random.uniform(*BOUNDS[1]),
            np.random.uniform(*BOUNDS[2]),
        ])

        # make agent and trajectory

        ag = CenterlineInferringAgent(
            tau=p[0], noise=p[1], bias=p[2], threshold=np.inf,
            hit_trigger='peak', hit_influence=0,
            k_0=np.eye(2), k_s=np.eye(2), tau_memory=1, bounds=BOUNDS)

        traj = ag.track(pl, start_pos, DURATION, DT)

        speeds = np.linalg.norm(traj['vs'], axis=1)
        ws = np.linalg.norm(angular_velocity(traj['vs'], DT), axis=1)
        ws = ws[~np.isnan(ws)]
        ys = traj['xs'][:, 1]

        return traj, speeds, ws, ys

    # define function to be optimized

    def optim_fun(p):

        _, speeds, ws, ys = simulate(p)

        ks_speeds = stats.ks_2samp(speeds, empirical['speeds'])[0]
        ks_ws = stats.ks_2samp(ws, empirical['ws'])[0]
        ks_ys = stats.ks_2samp(ys, empirical['ys'])[0]

        val = ks_speeds + ks_ws + ks_ys

        # punish unallowable values

        if np.any(p < 0):

            val += 10000

        return val

    # optimize it

    p_best = optimize.fmin(optim_fun, np.array(INITIAL_PARAMS), maxiter=MAX_ITERS)

    # generate one final trajectory

    traj, speeds, ws, ys = simulate(p_best)

    # make plots of things that have been optimized

    ## get bins

    speed_max = max(speeds.max(), empirical['speeds'].max())
    bins_speed = np.linspace(0, speed_max, 41, endpoint=True)
    bincs_speed = 0.5 * (bins_speed[:-1] + bins_speed[1:])

    w_max = max(ws.max(), empirical['ws'].max())
    bins_w = np.linspace(0, w_max, 41, endpoint=True)
    bincs_w = 0.5 * (bins_w[:-1] + bins_w[1:])

    bins_y = np.linspace(BOUNDS[1][0], BOUNDS[1][1], 41, endpoint=True)
    bincs_y = 0.5 * (bins_y[:-1] + bins_y[1:])

    # density=True replaces the removed `normed` keyword; for these equal-width
    # bins it gives the same normalized histogram
    cts_speed, _ = np.histogram(speeds, bins=bins_speed, density=True)
    cts_speed_empirical, _ = np.histogram(empirical['speeds'], bins=bins_speed, density=True)

    cts_w, _ = np.histogram(ws, bins=bins_w, density=True)
    cts_w_empirical, _ = np.histogram(empirical['ws'], bins=bins_w, density=True)

    cts_y, _ = np.histogram(ys, bins=bins_y, density=True)
    cts_y_empirical, _ = np.histogram(empirical['ys'], bins=bins_y, density=True)

    fig = plt.figure(figsize=(15, 8), tight_layout=True)
    axs = []

    # top row: three distribution comparisons (data in black, model in red)
    axs.append(fig.add_subplot(2, 3, 1))
    axs.append(fig.add_subplot(2, 3, 2))
    axs.append(fig.add_subplot(2, 3, 3))

    axs[0].plot(bincs_speed, cts_speed_empirical, lw=2, color='k')
    axs[0].plot(bincs_speed, cts_speed, lw=2, color='r')

    axs[0].set_xlabel('speed (m/s)')
    axs[0].set_ylabel('rel. counts')

    axs[0].legend(['data', 'model'], fontsize=16)

    axs[1].plot(bincs_w, cts_w_empirical, lw=2, color='k')
    axs[1].plot(bincs_w, cts_w, lw=2, color='r')

    axs[1].set_xlabel('ang. vel. (rad/s)')

    axs[2].plot(bincs_y, cts_y_empirical, lw=2, color='k')
    axs[2].plot(bincs_y, cts_y, lw=2, color='r')

    axs[2].set_xlabel('y (m)')

    # bottom row: first 500 points of the final trajectory, start marked in red
    axs.append(fig.add_subplot(2, 1, 2))

    axs[3].plot(traj['xs'][:500, 0], traj['xs'][:500, 1], lw=2, color='k', zorder=0)
    axs[3].scatter(traj['xs'][0, 0], traj['xs'][0, 1], lw=0, c='r', zorder=1, s=100)

    axs[3].set_xlim(*BOUNDS[0])
    axs[3].set_ylim(*BOUNDS[1])

    axs[3].set_xlabel('x (m)')
    axs[3].set_ylabel('y (m)')

    axs[3].set_title('example trajectory')

    for ax in axs:

        set_font_size(ax, 16)

    # print out parameters

    print('best params:')
    print('tau = {}'.format(p_best[0]))
    print('noise = {}'.format(p_best[1]))
    print('bias = {}'.format(p_best[2]))

    return fig
Y_LIM = 40, 150


figs = []
axss = []
for _ in range(3):
    fig, axs = plt.subplots(2, 3, facecolor=FACE_COLOR, figsize=FIG_SIZE, tight_layout=True)
    figs += [fig]
    axss += [axs]

for s_ctr, sg_id_template in enumerate((SEGMENT_GROUP_ID_EMPIRICAL, SEGMENT_GROUP_ID_INFOTAXIS)):
    for e_ctr, expt in enumerate(EXPERIMENTS):
        for o_ctr, odor_state in enumerate(ODOR_STATES):
            for c in ('early', 'late'):
                sg_id = sg_id_template.format(expt, odor_state)
                seg_group = session.query(models.SegmentGroup).get(sg_id)

                heading_ensemble = None
                for ens in seg_group.analysis_triggered_ensembles:
                    # only get the heading ensemble if it has correct conditions
                    conditions = (ens.variable == 'heading',
                                  ens.trigger_start == 'exit',
                                  59 < ens.heading_min < 61,
                                  119 < ens.heading_max < 121,
                                  ens.x_idx_max == 50,
                                  ens.x_idx_min == 15)

                    if c == 'early':
                        conditions += (ens.encounter_number_min == 1,
                                       ens.encounter_number_max == 2,)
                    elif c == 'late':
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from db_api import models
from db_api.connect import session

# Script configuration for interactively browsing crossing groups.
ODOR_STATES = ('on', 'none', 'afterodor')
TIMEPOINTS_BEFORE_ENTRY = 50
TIMEPOINTS_AFTER_EXIT = 50

# plot appearance
FACE_COLOR = 'white'
FIG_SIZE = (8, 10)
LW = 2
# interactive mode, so the figure can be updated while the loop below runs
plt.ion()


# browsing state: all experiments plus counters for experiment and odor state
expts = session.query(models.Experiment).all()
keep_going = True
e_ctr = 0
o_ctr = 0

# three stacked axes; the last one gets a second y axis stored on the axis object
fig, axs = plt.subplots(3, 1, facecolor=FACE_COLOR, figsize=FIG_SIZE, tight_layout=True)
axs[2].twin = axs[2].twinx()

while keep_going:

    # get new crossing group
    expt = expts[e_ctr]
    odor_state = ODOR_STATES[o_ctr]
    # NOTE(review): as written this binds a Query, not a Threshold row (there is
    # no .first()/.one() here) -- confirm how it is resolved further down
    threshold = session.query(models.Threshold).\
        filter_by(experiment=expt, determination='chosen')
示例#32
0
def main():
    """
    Compute and store discrimination thresholds for the odor-on crossing groups.

    For every candidate odor threshold of an experiment's insect, the crossings of a
    group are split by max odor (below vs. at/above the threshold), the response
    variable is read over the averaging window for both sets, and the time-averaged
    difference with its bounds is saved as a DiscriminationThreshold. When either
    set is empty the difference and bounds are stored as None.
    """

    n_timesteps = TIME_AVG_END - TIME_AVG_START

    def collect_responses(crossings):
        """Return a (len(crossings), n_timesteps) array of responses, NaN-padded."""
        responses = np.nan * np.ones((len(crossings), n_timesteps), dtype=float)

        # each row is a view into the array, so assigning to it fills the array
        for crossing, response in zip(crossings, responses):
            response_var = crossing.timepoint_field(session, RESPONSE_VAR,
                                                    first=TIME_AVG_START,
                                                    last=TIME_AVG_END - 1,
                                                    first_rel_to=TIME_AVG_REL_TO,
                                                    last_rel_to=TIME_AVG_REL_TO)
            # a short response leaves the remaining entries as NaN
            response[:len(response_var)] = response_var

        return responses

    for expt in session.query(models.Experiment):
        # NOTE(review): the Threshold filter below has no visible join between
        # CrossingGroup and Threshold; confirm it restricts groups as intended
        for cg in session.query(models.CrossingGroup).\
            filter(models.CrossingGroup.experiment == expt).\
            filter(models.CrossingGroup.odor_state == 'on').\
            filter(models.Threshold.determination == 'arbitrary'):
            print('Crossings group: "{}"'.format(cg.id))

            for th_val in DISCRIMINATION_THRESHOLD_VALUES[expt.insect]:

                crossings_below = session.query(models.Crossing).\
                    filter(models.Crossing.crossing_group == cg).\
                    filter(models.Crossing.max_odor < th_val).all()
                crossings_above = session.query(models.Crossing).\
                    filter(models.Crossing.crossing_group == cg).\
                    filter(models.Crossing.max_odor >= th_val).all()

                # fill in values
                responses_below = collect_responses(crossings_below)
                responses_above = collect_responses(crossings_above)

                if crossings_below and crossings_above:
                    diff, lb, ub = get_time_avg_response_diff_and_bounds(responses_below,
                                                                         responses_above)
                else:
                    # nothing to compare: do not run the statistic on an empty
                    # group, just record that no difference is available
                    diff = None
                    lb = None
                    ub = None

                disc_th = models.DiscriminationThreshold(crossing_group=cg,
                                                         odor_threshold=th_val,
                                                         n_crossings_below=len(crossings_below),
                                                         n_crossings_above=len(crossings_above),
                                                         time_avg_start=TIME_AVG_START,
                                                         time_avg_end=TIME_AVG_END,
                                                         time_avg_rel_to=TIME_AVG_REL_TO,
                                                         variable=RESPONSE_VAR,
                                                         time_avg_difference=diff,
                                                         lower_bound=lb,
                                                         upper_bound=ub)

                session.add(disc_th)
                commit(session)
示例#33
0
}

# 3 x 2 grid of axes sharing the x axis
fig, axs_unflat = plt.subplots(3,
                               2,
                               facecolor=FACE_COLOR,
                               figsize=FIG_SIZE,
                               sharex=True,
                               tight_layout=True)

# flattened for index access; every axis from the third on gets a second y axis
axs = axs_unflat.flatten()
axs_twin = [ax.twinx() for ax in axs[2:]]

wind_speed_handles = []
wind_speed_labels = []

for expt in session.query(models.Experiment):

    # label depends on the insect named in the experiment id; an id containing
    # neither substring leaves wind_speed_label unset (or stale from a previous pass)
    if 'fly' in expt.id:

        wind_speed_label = 'fly {} m/s'.format(expt.wind_speed)

    elif 'mosquito' in expt.id:

        wind_speed_label = 'mosquito {} m/s'.format(expt.wind_speed)

    print(expt.id)

    # get threshold and crossing group for this experiment
    threshold = session.query(models.Threshold).filter_by(
        experiment=expt, determination=DETERMINATION).first()
# Script configuration: which threshold to analyze and how to split/display crossings.
THRESHOLD_ID = 4  # look up in threshold table
# crossings are split by whether their max odor is below or at/above this value
DISCRIMINATION_THRESHOLD = 450
DISPLAY_START = -50
DISPLAY_END = 150
INTEGRAL_START = 0  # timepoints (fs 100 Hz) relative to peak
INTEGRAL_END = 100
VARIABLE = 'heading_xyz'

# plot appearance
FACE_COLOR = 'white'
FIG_SIZE = (10, 10)
LW = 2
COLORS = ['k', 'r']  # [below, above] discrimination threshold


# get threshold and crossing group for odor on trajectories
threshold = session.query(models.Threshold).get(THRESHOLD_ID)
cg = session.query(models.CrossingGroup).\
    filter_by(threshold=threshold, odor_state='on').first()

# every crossing of the group, regardless of its max odor
all_crossings = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).all()

# get all crossings where max odor is below/above discrimination threshold
# (a crossing exactly at the threshold counts as "above")
crossings_below = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).\
    filter(models.Crossing.max_odor < DISCRIMINATION_THRESHOLD).all()
crossings_above = session.query(models.Crossing).\
    filter(models.Crossing.crossing_group == cg).\
    filter(models.Crossing.max_odor >= DISCRIMINATION_THRESHOLD).all()

# make array for storing crossing-specific time-series
示例#35
0
def response_discrimination_via_thresholds(
        RESPONSE_VAR,
        CROSSING_GROUP_IDS,
        CROSSING_GROUP_LABELS,
        CROSSING_GROUP_X_LIMS,
        AX_SIZE, FONT_SIZE):
    """
    Plot the difference between the mean time-averaged headings for plume-crossings
    above and below an odor concentration threshold as a function of that threshold.

    One axis of a 2 x 2 grid is used per crossing group; the shaded band spans the
    stored lower and upper bounds.
    :param RESPONSE_VAR: variable whose discrimination thresholds are plotted
    :param CROSSING_GROUP_IDS: ids of the crossing groups to plot (at most four are drawn)
    :param CROSSING_GROUP_LABELS: mapping from crossing group id to axis title
    :param CROSSING_GROUP_X_LIMS: mapping from crossing group id to x limits
    :param AX_SIZE: accepted for interface compatibility; the figure size is fixed
    :param FONT_SIZE: font size applied to every axis
    :return: figure
    """

    # NOTE(review): presumably converts raw sensor units to % ethanol for fly
    # experiments -- confirm the origin of these two numbers
    concentration_factor = 0.0476 / 526

    fig, axs = plt.subplots(
        2, 2, facecolor='white',
        figsize=(7.5, 5), tight_layout=True)

    axs = axs.flatten()

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs):

        cg = session.query(models.CrossingGroup).filter_by(id=cg_id).first()

        # fetch the rows once, so all four arrays below come from the same result
        # set instead of re-running the query for each of them
        disc_ths = session.query(models.DiscriminationThreshold).\
            filter_by(crossing_group=cg, variable=RESPONSE_VAR).all()

        # float dtype so the in-place scaling below also works for integer thresholds
        ths = np.array([disc_th.odor_threshold for disc_th in disc_ths], dtype=float)

        if 'fly' in cg_id:

            ths *= concentration_factor

        # dtype=float turns stored None values into NaN, which are simply not drawn
        means = np.array([disc_th.time_avg_difference for disc_th in disc_ths], dtype=float)
        lbs = np.array([disc_th.lower_bound for disc_th in disc_ths], dtype=float)
        ubs = np.array([disc_th.upper_bound for disc_th in disc_ths], dtype=float)

        ax.plot(ths, means, color='k', lw=3)
        ax.fill_between(ths, lbs, ubs, color='k', alpha=.3)

        ax.set_xlim(CROSSING_GROUP_X_LIMS[cg.id])

        if 'fly' in cg_id:

            ax.set_xlabel('threshold (% ethanol)')

        else:

            ax.set_xlabel('threshold (ppm CO2)')

        for xtl in ax.get_xticklabels():

            xtl.set_rotation(60)

        ax.set_ylabel('mean heading diff.\nbetween groups')
        ax.set_title(CROSSING_GROUP_LABELS[cg.id])

    for ax in axs:

        set_fontsize(ax, FONT_SIZE)

    return fig
示例#36
0
                               above_z, below_z])

    return np.min(dist_all_walls, axis=0)


# interactive figure with four stacked axes; the last one gets a second y axis
# stored on the axis object
plt.ion()
fig, axs = plt.subplots(4, 1, facecolor=FACE_COLOR, figsize=FIG_SIZE, tight_layout=True)
axs[3].twin = axs[3].twinx()

# loop through odor states
for odor_state in ODOR_STATES:
    print('Current odor state: {}'.format(odor_state))
    # e.g. '<EXPERIMENT_ID>_odor_<odor_state>'
    sample_group = '_'.join([EXPERIMENT_ID, 'odor', odor_state])

    # raw trajectories of this experiment and odor state
    trajs = session.query(models.Trajectory).filter_by(experiment_id=EXPERIMENT_ID,
                                                       odor_state=odor_state,
                                                       raw=True).all()

    # loop through trajectories
    t_ctr = 0
    n_trajs = len(trajs)

    # only enter the browsing loop when there is something to show; this also
    # protects the modulo below from a zero divisor
    if trajs:
        keep_going = True
    else:
        keep_going = False

    while keep_going:

        # wrap the counter so browsing cycles through the trajectories
        t_ctr %= n_trajs
示例#37
0
# labels for the rows and columns of each 3 x 3 figure
row_labels = ('0.3 m/s', '0.4 m/s', '0.6 m/s')
col_labels = ('on', 'none', 'afterodor')

# one figure per simulation-id template (empirical, then infotaxis)
for sim_id_template in (SIMULATION_ID_EMPIRICAL, SIMULATION_ID_INFOTAXIS):
    fig, axs = plt.subplots(3,
                            3,
                            facecolor='white',
                            figsize=FIG_SIZE,
                            tight_layout=True)

    for e_ctr, expt in enumerate(EXPERIMENTS):
        for o_ctr, odor_state in enumerate(ODOR_STATES):

            sim_id = sim_id_template.format(expt, odor_state)
            sim = session.query(models.Simulation).get(sim_id)

            # load the histogram data before reading its projections
            sim.analysis_displacement_total_histogram.fetch_data(session)

            # pick the 2D projection, its extent and the axis labels
            if PROJECTION == 'xy':
                heatmap = sim.analysis_displacement_total_histogram.xy
                extent = sim.analysis_displacement_total_histogram.extent_xy
                xlabel = 'x'
                ylabel = 'y'
            elif PROJECTION == 'xz':
                heatmap = sim.analysis_displacement_total_histogram.xz
                extent = sim.analysis_displacement_total_histogram.extent_xz
                xlabel = 'x'
                ylabel = 'z'
            elif PROJECTION == 'yz':
                heatmap = sim.analysis_displacement_total_histogram.yz
def main(n_trials, n_train_max, n_test_max, root_dir_env_var):
    """
    Fit GLMs to random train/test splits of trajectories and store the results.

    For every experiment in EXPERIMENT_IDS and odor state in ODOR_STATES, the
    trajectories returned by igfh.get_trajs_with_integrated_odor_above_threshold
    are split at random into a training and a test set, n_trials times. For
    each split, one GLM per (input set, output, input basis, output basis)
    tuple is fit to the training set and its root-mean-square prediction error
    on the test set is recorded. The results of all trials are written to one
    data file per experiment/odor state via models.GlmFitSet.save_to_file, and
    the GlmFitSet row is added to the session and committed.

    :param n_trials: number of random train/test splits per experiment and
        odor state
    :param n_train_max: upper bound on the number of training trajectories
    :param n_test_max: upper bound on the number of test trajectories
    :param root_dir_env_var: value stored in GlmFitSet.root_dir_env_var
    """

    # make basis functions
    basis_ins, basis_outs, max_filter_length = igfh.make_exponential_basis_functions(
        INPUT_TAUS, OUTPUT_TAUS, DOMAIN_FACTOR
    )

    for expt_id in EXPERIMENT_IDS:
        for odor_state in ODOR_STATES:

            trajs = igfh.get_trajs_with_integrated_odor_above_threshold(
                expt_id, odor_state, INTEGRATED_ODOR_THRESHOLD
            )

            # float() keeps these true divisions even when the module does not
            # enable `from __future__ import division`
            n_max_total = float(n_train_max + n_test_max)
            train_test_ratio = n_train_max / n_max_total
            test_train_ratio = n_test_max / n_max_total

            # np.floor returns a float, which is not a valid slice bound for a
            # numpy array, so convert the set sizes to int
            n_train = int(min(n_train_max, np.floor(len(trajs) * train_test_ratio)))
            n_test = int(min(n_test_max, np.floor(len(trajs) * test_train_ratio)))

            trajs_trains = []
            trajs_tests = []
            glmss = []
            residualss = []

            # defined up front so that n_glms below is valid even if n_trials is 0
            glms = []

            for trial_ctr in range(n_trials):
                print('{}: odor {} (trial number: {})'.format(expt_id, odor_state, trial_ctr))

                # get random set of training and test trajectories
                perm = np.random.permutation(len(trajs))
                train_idxs = perm[:n_train]
                # NOTE(review): when n_test is 0 this slice selects every
                # index (perm[-0:] == perm[:]) rather than none -- confirm
                # whether that case can occur
                test_idxs = perm[-n_test:]

                trajs_train = [trajs[idx] for idx in train_idxs]
                trajs_test = [trajs[idx] for idx in test_idxs]

                # fit and evaluate one glm per input set/output/basis tuple
                glms = []
                residuals = []
                for input_set, output, basis_in, basis_out in zip(INPUT_SETS, OUTPUTS, basis_ins, basis_outs):

                    # get relevant time-series data from each trajectory set
                    data_train = igfh.time_series_from_trajs(
                        trajs_train,
                        inputs=input_set,
                        output=output
                    )
                    data_test = igfh.time_series_from_trajs(
                        trajs_test,
                        inputs=input_set,
                        output=output
                    )

                    glm = fitting.GLMFitter(link=LINK, family=FAMILY)
                    # NOTE(review): basis_out from the zip above is not passed
                    # here (basis_out=False) -- presumably intentional; confirm
                    glm.set_params(DELAY, basis_in=basis_in, basis_out=False)

                    glm.input_set = input_set
                    glm.output = output

                    # fit to training data
                    glm.fit(data=data_train, start=START_TIMEPOINT)

                    # predict test data
                    prediction = glm.predict(data=data_test, start=START_TIMEPOINT)
                    _, ground_truth = glm.make_feature_matrix_and_response_vector(data_test, START_TIMEPOINT)

                    # calculate residual (root-mean-square prediction error)
                    residual = np.sqrt(((prediction - ground_truth)**2).mean())

                    # clear out feature matrix and response from glm for efficient storage
                    glm.feature_matrix = None
                    glm.response_vector = None
                    glm.results.remove_data()

                    # store things
                    glms.append(glm)
                    residuals.append(residual)

                # only the trajectory ids are stored, not the trajectories themselves
                trajs_trains.append([traj.id for traj in trajs_train])
                trajs_tests.append([traj.id for traj in trajs_test])
                glmss.append(glms)
                residualss.append(residuals)

            # save a glm fit set
            glm_fit_set = models.GlmFitSet()

            # add data to it
            glm_fit_set.root_dir_env_var = root_dir_env_var
            glm_fit_set.path_relative = 'glm_fit'
            glm_fit_set.file_name = '{}_{}_odor_{}.pickle'.format(FIT_NAME, expt_id, odor_state)
            glm_fit_set.experiment = session.query(models.Experiment).get(expt_id)
            glm_fit_set.odor_state = odor_state
            glm_fit_set.name = FIT_NAME
            glm_fit_set.link = LINK
            glm_fit_set.family = FAMILY
            glm_fit_set.integrated_odor_threshold = INTEGRATED_ODOR_THRESHOLD
            glm_fit_set.predicted = PREDICTED
            glm_fit_set.delay = DELAY
            glm_fit_set.start_time_point = START_TIMEPOINT
            # number of glms fit in the last trial (0 if no trial was run)
            glm_fit_set.n_glms = len(glms)
            glm_fit_set.n_train = n_train
            glm_fit_set.n_test = n_test
            glm_fit_set.n_trials = n_trials

            # save data file
            glm_fit_set.save_to_file(
                input_sets=INPUT_SETS,
                outputs=OUTPUTS,
                basis_in=basis_ins,
                basis_out=basis_outs,
                trajs_train=trajs_trains,
                trajs_test=trajs_tests,
                glms=glmss,
                residuals=residualss
            )

            # save everything else (+ link to data file) in database
            session.add(glm_fit_set)

            commit(session)
from db_api.connect import session

from config import *
from config.position_heatmap import *

# Labels -- presumably for the rows and columns of the 3x3 grids built below;
# their use is outside this excerpt (TODO confirm).
row_labels = ('0.3 m/s', '0.4 m/s', '0.6 m/s')
col_labels = ('on', 'none', 'afterodor')

# one 3x3 figure per simulation id template
for sim_id_template in (SIMULATION_ID_EMPIRICAL, SIMULATION_ID_INFOTAXIS):
    fig, axs = plt.subplots(3, 3, facecolor='white', figsize=FIG_SIZE, tight_layout=True)

    for e_ctr, expt in enumerate(EXPERIMENTS):
        for o_ctr, odor_state in enumerate(ODOR_STATES):

            # look up the simulation whose id is the template filled in with
            # this experiment and odor state
            sim_id = sim_id_template.format(expt, odor_state)
            sim = session.query(models.Simulation).get(sim_id)

            # get histogram: the first one whose n_timesteps equals N_TIMESTEPS
            # NOTE(review): hist stays None if no histogram matches, in which
            # case hist.fetch_data below raises AttributeError -- confirm a
            # match is always expected
            hist = None
            for h in sim.analysis_displacement_after_n_timesteps_histograms:
                if h.n_timesteps == N_TIMESTEPS:
                    hist = h
                    break

            # load the histogram's data before reading its projections
            hist.fetch_data(session)

            # pick the 2D projection, its extent and the axis labels that
            # correspond to PROJECTION
            if PROJECTION == 'xy':
                heatmap = hist.xy
                extent = hist.extent_xy
                xlabel = 'x'
                ylabel = 'y'
# Fill rows 1-4 of a 5x2 subplot grid. The left column shares both axes with
# axs[0, 0]; the right column shares its x-axis with axs[0, 0] and its y-axis
# with axs[0, 1].
for ii in range(1, 5):
    axs[ii, 0] = fig.add_subplot(5, 2, 2*ii + 1, sharex=axs[0, 0], sharey=axs[0, 0])
    axs[ii, 1] = fig.add_subplot(5, 2, 2*(ii + 1), sharex=axs[0, 0], sharey=axs[0, 1])

# two more figures of four stacked axes each; the first shares both x and y
# axes across its rows, the second only the x-axis
fig_tp, axs_tp = plt.subplots(
    4, 1, figsize=FIG_SIZE_POS, facecolor=FACE_COLOR, sharex=True, sharey=True, tight_layout=True
)
fig_pos, axs_pos = plt.subplots(
    4, 1, figsize=FIG_SIZE_POS, facecolor=FACE_COLOR, sharex=True, tight_layout=True
)

# accumulated across experiments -- presumably for a legend (TODO confirm)
wind_speed_handles = []
wind_speed_labels = []

for expt_ctr, expt in enumerate(session.query(models.Experiment)):

    print(expt.id)

    # label prefix depends on the insect named in the experiment id:
    # 'f' for fruitfly, 'm' for mosquito; other ids add no label
    if 'fruitfly' in expt.id:

        wind_speed_labels.append('f {} m/s'.format(expt.wind_speed))

    elif 'mosquito' in expt.id:

        wind_speed_labels.append('m {} m/s'.format(expt.wind_speed))


    # per-experiment accumulators, reset on every iteration
    subset_handles = []
    subset_labels = []
    subset_handles_pos = []
from db_api.connect import session
from db_api import models

from config.wind_tunnel_discretized_matched_hit_number_histogram import *


# hit_numbers[expt][odor_state] -> list with the number of hits in each trial
hit_numbers = dict((expt, {}) for expt in EXPERIMENTS)

for expt in EXPERIMENTS:
    hits_by_odor_state = hit_numbers[expt]

    for odor_state in ODOR_STATES:
        # every trial of the simulation for this experiment and odor state
        sim_id = SIMULATION_ID.format(expt, odor_state)
        print('Getting hit number distribution from simulation "{}" ...'.format(sim_id))

        trials = session.query(models.Trial).filter_by(simulation_id=sim_id)

        # count the hits (timepoints with detected odor) in each trial
        hits_per_trial = []
        for trial in trials:
            detections = [tp.detected_odor for tp in trial.get_timepoints(session)]
            n_hits = np.sum(detections)
            hits_per_trial.append(n_hits)

        hits_by_odor_state[odor_state] = hits_per_trial


# 3x3 grid of axes for the hit number histograms (one per experiment and odor_state)
fig, axs = plt.subplots(
    3, 3,
    facecolor='white',
    tight_layout=True,
)

for e_ctr, expt in enumerate(EXPERIMENTS):
              'angular_velocity_x',
              'angular_velocity_y',
              'angular_velocity_z',
              'angular_velocity_a',
              'angular_acceleration_x',
              'angular_acceleration_y',
              'angular_acceleration_z',
              'angular_acceleration_a',
              'distance_from_wall']

# plot appearance
FACE_COLOR = 'white'
AX_SIZE = (8, 4)  # size of one axis; the figure height scales with the number of experiments
DATA_COLORS = ('b', 'g', 'r')  # paired with ODOR_STATES below
LW = 2

# one axis per experiment, stacked vertically in a single figure
expts = list(session.query(models.Experiment))
fig_size = (AX_SIZE[0], AX_SIZE[1] * len(expts))
fig, axs = plt.subplots(len(expts), 1, facecolor=FACE_COLOR, figsize=fig_size, tight_layout=True)

for quantity in QUANTITIES:
    print('Showing distributions of "{}"'.format(quantity))

    # clear every axis before drawing the next quantity
    [ax.cla() for ax in axs]

    for ax, expt in zip(axs, expts):
        for color, odor_state in zip(DATA_COLORS, ODOR_STATES):

            # get the autocorrelation meta data from the database
            # (.first() yields None when no matching row exists)
            tpa = session.query(models.TimepointAutocorrelation).\
                filter_by(variable=quantity, experiment=expt, odor_state=odor_state).first()
示例#43
0
# figure sizes for the different plots -- used outside this excerpt
FIG_SIZE_EXAMPLES = (18, 15)
FIG_SIZE_FILTERS = (14, 15)
FIG_SIZE_RESIDUALS = (14, 15)

FACE_COLOR = 'w'
FONT_SIZE = 16

# display names -- presumably one per fitted model, in fit order (TODO confirm)
MODEL_LABELS = (
    'constant',
    'x',
    'x, odor',
    'x, odor_short',
    'x, odor_long',
)

# load the stored fit set for the configured fit name, experiment and odor state
# NOTE(review): .first() returns None when no row matches, in which case the
# attribute reads below fail -- confirm the fit set is always present
glm_fit_set = session.query(models.GlmFitSet).filter_by(
    name=FIT_NAME, experiment_id=EXPERIMENT, odor_state=ODOR_STATE).first()

start_time_point = glm_fit_set.start_time_point
delay = glm_fit_set.delay
n_glms = glm_fit_set.n_glms

# get GLMs for example trial
glms = glm_fit_set.glms[EXAMPLE_TRIAL]

# plot example trajectories
# (the fit set stores trajectory ids, indexed first by trial and then by
# position within that trial's training/test set)
traj_train_id = glm_fit_set.trajs_train[EXAMPLE_TRIAL][EXAMPLE_TRAIN]
traj_test_id = glm_fit_set.trajs_test[EXAMPLE_TRIAL][EXAMPLE_TEST]

# resolve the ids to trajectory objects
traj_train = session.query(models.Trajectory).get(traj_train_id)
traj_test = session.query(models.Trajectory).get(traj_test_id)
示例#44
0
from __future__ import print_function, division
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from db_api import models
from db_api.connect import session

ODOR_STATES = ('on', 'none', 'afterodor')
# number of timepoints shown around a crossing -- presumably before its entry
# and after its exit; used outside this excerpt (TODO confirm)
TIMEPOINTS_BEFORE_ENTRY = 50
TIMEPOINTS_AFTER_EXIT = 50

# plot appearance
FACE_COLOR = 'white'
FIG_SIZE = (8, 10)
LW = 2
# interactive mode, so figures update without blocking the script
plt.ion()

expts = session.query(models.Experiment).all()
keep_going = True
# counters into expts and -- presumably -- ODOR_STATES (TODO confirm)
e_ctr = 0
o_ctr = 0

# three stacked axes; the bottom axis also gets a twin y-axis, stored on the
# axis object itself as `twin`
fig, axs = plt.subplots(3,
                        1,
                        facecolor=FACE_COLOR,
                        figsize=FIG_SIZE,
                        tight_layout=True)
axs[2].twin = axs[2].twinx()

while keep_going:

    # get new crossing group
    expt = expts[e_ctr]
示例#45
0
def early_vs_late_heading_timecourse_x0_accounted_for(
        CROSSING_GROUP_IDS, CROSSING_GROUP_LABELS,
        X_0_MIN, X_0_MAX, H_0_MIN, H_0_MAX, CROSSING_NUMBER_MAX,
        MAX_CROSSINGS_EARLY, SUBTRACT_INITIAL_HEADING,
        T_BEFORE, T_AFTER, SCATTER_INTEGRATION_WINDOW,
        AX_SIZE, AX_GRID, EARLY_LATE_COLORS, ALPHA,
        P_VAL_COLOR, P_VAL_Y_LIM, LEGEND_CROSSING_GROUP_ID,
        FONT_SIZE):
    """
    Show early vs. late headings for different experiments, along with a plot of the
    p-values for the difference between the two means.

    For each crossing group, crossings are split into "early" (crossing number
    <= MAX_CROSSINGS_EARLY) and "late" (MAX_CROSSINGS_EARLY < crossing number
    <= CROSSING_NUMBER_MAX). At every time step a linear regression of heading
    on x_0 (position_x at the crossing's peak) is fit to the pooled early and
    late data, and the residuals are what gets plotted and compared, so that
    differences explained by x_0 are removed.

    Two figures are made: the residual heading time courses with KS-test
    p-values on a twin axis, and a scatter plot of x_0 vs. mean heading
    colored by crossing number. Only the first figure is returned.

    Arguments, as used in this function:
        CROSSING_GROUP_IDS: ids of the models.CrossingGroup rows to analyze;
            one axis per id
        CROSSING_GROUP_LABELS: mapping from crossing group id to axis title
        X_0_MIN, X_0_MAX: inclusive range of position_x_peak for a crossing
            to be included
        H_0_MIN, H_0_MAX: inclusive range of heading_xyz_peak for a crossing
            to be included
        CROSSING_NUMBER_MAX: largest crossing number included in the "late"
            group; also the number of scatter-plot colors
        MAX_CROSSINGS_EARLY: largest crossing number counted as "early"
        SUBTRACT_INITIAL_HEADING: if truthy, subtract from each heading trace
            its value at index ts_before (presumably the peak time step --
            TODO confirm against Crossing.timepoint_field)
        T_BEFORE, T_AFTER: length of the window before/after the peak, in the
            same units as the module-level DT
        SCATTER_INTEGRATION_WINDOW: (start, end) times relative to the peak
            over which heading is averaged for the scatter plot
        AX_SIZE: size of a single axis, (width, height)
        AX_GRID: subplot grid shape, (n_rows, n_cols)
        EARLY_LATE_COLORS: mapping from 'early'/'late' to plot color
        ALPHA: opacity of the SEM bands
        P_VAL_COLOR: color of the p-value trace
        P_VAL_Y_LIM: y-limits of the p-value axis
        LEGEND_CROSSING_GROUP_ID: id of the crossing group whose axis gets
            the early/late legend
        FONT_SIZE: font size applied to the axes of both figures

    Returns:
        The figure with the heading time courses and p-values.
    """

    # convert times to time steps
    ts_before = int(round(T_BEFORE / DT))
    ts_after = int(round(T_AFTER / DT))
    n_steps = ts_before + ts_after
    # scatter window expressed as indices into the heading traces
    scatter_ts = [ts_before + int(round(t / DT)) for t in SCATTER_INTEGRATION_WINDOW]

    # loop over crossing groups
    x_0s_dict = {}
    headings_dict = {}
    residuals_dict = {}
    p_vals_dict = {}

    scatter_ys_dict = {}
    crossing_ns_dict = {}

    for cg_id in CROSSING_GROUP_IDS:

        # get crossing group
        crossing_group = session.query(models.CrossingGroup).filter_by(id=cg_id).first()

        # get early and late crossings
        crossings_all = session.query(models.Crossing).filter_by(
            crossing_group=crossing_group)
        crossings_dict = {
            'early': crossings_all.filter(
                models.Crossing.crossing_number <= MAX_CROSSINGS_EARLY),
            'late': crossings_all.filter(
                models.Crossing.crossing_number > MAX_CROSSINGS_EARLY,
                models.Crossing.crossing_number <= CROSSING_NUMBER_MAX),
        }

        x_0s_dict[cg_id] = {}
        headings_dict[cg_id] = {}

        scatter_ys_dict[cg_id] = {}
        crossing_ns_dict[cg_id] = {}

        for label in ['early', 'late']:

            x_0s = []
            headings = []
            scatter_ys = []
            crossing_ns = []

            # get all initial headings, initial xs, peak concentrations, and heading time-series
            for crossing in crossings_dict[label]:

                # sanity checks on the database filters above
                assert crossing.crossing_number > 0
                if label == 'early':
                    assert 0 < crossing.crossing_number <= MAX_CROSSINGS_EARLY
                elif label == 'late':
                    assert MAX_CROSSINGS_EARLY < crossing.crossing_number
                    assert crossing.crossing_number <= CROSSING_NUMBER_MAX

                # throw away crossings that do not meet trigger criteria
                x_0 = crossing.feature_set_basic.position_x_peak
                h_0 = crossing.feature_set_basic.heading_xyz_peak

                if not (X_0_MIN <= x_0 <= X_0_MAX):
                    continue
                if not (H_0_MIN <= h_0 <= H_0_MAX):
                    continue

                # store x_0 (uw/dw position)
                x_0s.append(x_0)

                # get headings in the window around the peak, nan-padded
                temp = crossing.timepoint_field(
                    session, 'heading_xyz', -ts_before, ts_after - 1,
                    'peak', 'peak', nan_pad=True)

                # subtract initial heading if desired
                if SUBTRACT_INITIAL_HEADING:
                    temp -= temp[ts_before]

                # store headings
                headings.append(temp)

                # calculate mean heading over integration window for scatter plot
                scatter_ys.append(np.nanmean(temp[scatter_ts[0]:scatter_ts[1]]))
                crossing_ns.append(crossing.crossing_number)

            x_0s_dict[cg_id][label] = np.array(x_0s)
            headings_dict[cg_id][label] = np.array(headings)

            scatter_ys_dict[cg_id][label] = np.array(scatter_ys)
            crossing_ns_dict[cg_id][label] = np.array(crossing_ns)

        x_early = x_0s_dict[cg_id]['early']
        x_late = x_0s_dict[cg_id]['late']
        h_early = headings_dict[cg_id]['early']
        h_late = headings_dict[cg_id]['late']

        # pool early and late so one regression is fit to both groups
        x0s_all = np.concatenate([x_early, x_late])
        hs_all = np.concatenate([h_early, h_late], axis=0)

        residuals_dict[cg_id] = {
            'early': np.nan * np.zeros(h_early.shape),
            'late': np.nan * np.zeros(h_late.shape),
        }

        # fit heading linear prediction from x0 at each time point
        # and subtract from original heading
        for t_step in range(n_steps):

            # get all headings for this time point
            hs_t = hs_all[:, t_step]
            residuals = np.nan * np.zeros(hs_t.shape)

            # only use headings that exist
            not_nan = ~np.isnan(hs_t)

            # fit linear model
            rgr = linear_model.LinearRegression()
            rgr.fit(x0s_all[not_nan][:, None], hs_t[not_nan])

            residuals[not_nan] = hs_t[not_nan] - rgr.predict(x0s_all[not_nan][:, None])

            assert np.all(np.isnan(residuals) == np.isnan(hs_t))

            # the early rows come first in the pooled arrays
            r_early, r_late = np.split(residuals, [len(x_early)])
            residuals_dict[cg_id]['early'][:, t_step] = r_early
            residuals_dict[cg_id]['late'][:, t_step] = r_late

        # loop through all time points and calculate p-value (ks-test)
        # between early and late
        p_vals = []

        for t_step in range(n_steps):

            early_with_nans = residuals_dict[cg_id]['early'][:, t_step]
            late_with_nans = residuals_dict[cg_id]['late'][:, t_step]

            early_no_nans = early_with_nans[~np.isnan(early_with_nans)]
            late_no_nans = late_with_nans[~np.isnan(late_with_nans)]

            # calculate statistical significance
            p_vals.append(ks_2samp(early_no_nans, late_no_nans)[1])

        p_vals_dict[cg_id] = p_vals

    ## MAKE PLOTS
    t = np.arange(-ts_before, ts_after) * DT

    # history-dependence
    fig_size = (AX_SIZE[0] * AX_GRID[1], AX_SIZE[1] * AX_GRID[0])
    fig_0, axs_0 = plt.subplots(*AX_GRID, figsize=fig_size, tight_layout=True)

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_0.flat):

        # get mean and sem of headings for early and late groups
        handles = []

        for label, color in EARLY_LATE_COLORS.items():
            headings_mean = np.nanmean(residuals_dict[cg_id][label], axis=0)
            headings_sem = stats.nansem(residuals_dict[cg_id][label], axis=0)

            handles.append(ax.plot(
                t, headings_mean, color=color, lw=2, label=label, zorder=1)[0])
            ax.fill_between(
                t, headings_mean - headings_sem, headings_mean + headings_sem,
                color=color, alpha=ALPHA, zorder=1)

        ax.set_xlabel('time since crossing (s)')

        if SUBTRACT_INITIAL_HEADING:
            ax.set_ylabel('heading* (deg.)')
        else:
            ax.set_ylabel('heading (deg.)')
        ax.set_title(CROSSING_GROUP_LABELS[cg_id])

        if cg_id == LEGEND_CROSSING_GROUP_ID:
            ax.legend(handles=handles, loc='upper right')
        set_fontsize(ax, FONT_SIZE)

        # plot p-value on a twin axis, with a reference line at 0.05
        ax_twin = ax.twinx()

        ax_twin.plot(t, p_vals_dict[cg_id], color=P_VAL_COLOR, lw=2, ls='--', zorder=0)
        ax_twin.axhline(0.05, ls='-', lw=2, color='gray')

        ax_twin.set_ylim(*P_VAL_Y_LIM)
        ax_twin.set_ylabel('p-value (KS test)', fontsize=FONT_SIZE)

        set_fontsize(ax_twin, FONT_SIZE)

    # scatter plots (this figure is created but not returned)
    fig_1, axs_1 = plt.subplots(*AX_GRID, figsize=fig_size, tight_layout=True)
    cc = np.concatenate
    colors = get_n_colors(CROSSING_NUMBER_MAX, colormap='jet')

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_1.flat):

        # make scatter plot of x0s vs integrated headings vs crossing number
        x_0s_all = cc([x_0s_dict[cg_id]['early'], x_0s_dict[cg_id]['late']])
        ys_all = cc([scatter_ys_dict[cg_id]['early'], scatter_ys_dict[cg_id]['late']])
        cs_all = cc([crossing_ns_dict[cg_id]['early'], crossing_ns_dict[cg_id]['late']])

        # crossing numbers start at 1, colors are indexed from 0
        cs = np.array([colors[c-1] for c in cs_all])

        hs = []

        # one scatter call per crossing number so each gets a legend entry
        for c in sorted(np.unique(cs_all)):
            label = 'cn = {}'.format(c)
            mask = cs_all == c
            h = ax.scatter(x_0s_all[mask], ys_all[mask],
                s=20, c=cs[mask], lw=0, label=label)
            hs.append(h)

        # calculate partial correlation between crossing number and heading given x
        not_nan = ~np.isnan(ys_all)
        r, p = stats.partial_corr(
            cs_all[not_nan], ys_all[not_nan], controls=[x_0s_all[not_nan]])

        ax.set_xlabel('x')
        ax.set_ylabel(r'$\Delta$h_mean({}:{}) (deg)'.format(
            *SCATTER_INTEGRATION_WINDOW))

        title = CROSSING_GROUP_LABELS[cg_id] + \
            ', R = {0:.2f}, P = {1:.3f}'.format(r, p)
        ax.set_title(title)

        ax.legend(handles=hs, loc='upper center', ncol=3)
        # use the caller's font size, as in the first figure
        set_fontsize(ax, FONT_SIZE)

    return fig_0
def main():
    """
    Add one Trial row per '.edr' file in the experiment directory to the session.

    The experiment is fetched or created first. For each '.edr' file (matched
    case-insensitively) that is not already stored as a Trial, the file is
    loaded with edr_handling.load_edr, and the trial pair id, odor, solenoid
    status and insect number are parsed from the file name with the
    module-level regular expressions. Any error while handling a file is
    printed and the loop moves on to the next file.

    NOTE(review): no commit is visible in this function -- confirm that the
    session is committed elsewhere.
    """
    # get or create experiment
    experiment = get_or_create(session, models.Experiment,
                               id=EXPERIMENT_ID,
                               description=EXPERIMENT_DESCRIPTION,
                               directory_path=EXPERIMENT_DIRECTORY_PATH)

    full_directory_path = os.path.join(ARENA_DATA_DIRECTORY, EXPERIMENT_DIRECTORY_PATH)

    edr_file_names = [file_name for file_name in os.listdir(full_directory_path)
                      if file_name.lower().endswith('.edr')]

    # get all trials and add them to database
    for file_name in edr_file_names:
        print('Attempting to load file "{}"'.format(file_name))

        # skip if file already added
        if session.query(models.Trial).filter_by(file_name=file_name).all():
            print('Skipping file "{}" because it is already in the database.'.format(file_name))
            continue

        try:
            file_path = os.path.join(full_directory_path, file_name)
            _, recording_start, _, header = edr_handling.load_edr(file_path)
            recording_duration = header['recording_duration']

            # get datetime for trial pair id using regex
            trial_pair_id = re.findall(DATETIME_EXPRESSION, file_name)[0]

            # get or create trial pair
            trial_pair = get_or_create(session, models.TrialPair, id=trial_pair_id)

            # get odor
            odor = re.findall(ODOR_TYPE_EXPRESSION, file_name)[0]

            # get solenoid status: anything other than 'off' counts as active
            solenoid_status = re.findall(SOLENOID_STATUS_EXPRESSION, file_name)[0]
            solenoid_active = solenoid_status != 'off'

            odor_status = models.TrialOdorStatus(odor=odor, solenoid_active=solenoid_active)
            session.add(odor_status)

            # get insect number from file name using regex
            insect_number = re.findall(INSECT_NUMBER_EXPRESSION, file_name)[0]

            # insect id is the recording date followed by the insect number
            insect_id = '{}_{}'.format(recording_start.strftime('%Y%m%d'), insect_number)

            # get/create insect
            insect = get_or_create(session, models.Insect, id=insect_id)

            trial = models.Trial(file_name=file_name,
                                 recording_start=recording_start,
                                 recording_duration=recording_duration)

            trial.insect = insect
            trial.experiment = experiment
            trial.odor_status = odor_status
            trial.pair = trial_pair

            session.add(trial)

        # deliberately broad: one unreadable or oddly named file must not
        # abort the import of the remaining files
        except Exception as e:
            print('Error: "{}"'.format(e))
# names of the timepoint quantities to loop over when plotting
QUANTITIES = [
    'odor', 'position_x', 'position_y', 'position_z', 'velocity_x',
    'velocity_y', 'velocity_z', 'velocity_a', 'acceleration_x',
    'acceleration_y', 'acceleration_z', 'acceleration_a', 'heading_xy',
    'heading_xz', 'heading_xyz', 'angular_velocity_x', 'angular_velocity_y',
    'angular_velocity_z', 'angular_velocity_a', 'angular_acceleration_x',
    'angular_acceleration_y', 'angular_acceleration_z',
    'angular_acceleration_a', 'distance_from_wall'
]

# plot appearance
FACE_COLOR = 'white'
AX_SIZE = (8, 4)  # size of one axis; the figure height scales with the number of experiments
DATA_COLORS = ('b', 'g', 'r')
LW = 2

# one axis per experiment, stacked vertically in a single figure
expts = list(session.query(models.Experiment))
fig_size = (AX_SIZE[0], AX_SIZE[1] * len(expts))
fig, axs = plt.subplots(len(expts),
                        1,
                        facecolor=FACE_COLOR,
                        figsize=fig_size,
                        tight_layout=True)

for quantity in QUANTITIES:
    print('Showing distributions of "{}"'.format(quantity))

    [ax.cla() for ax in axs]

    for ax, expt in zip(axs, expts):
        for color, odor_state in zip(DATA_COLORS, ODOR_STATES):
        'fruitfly_0.6mps_checkerboard_floor': 4,
        'mosquito_0.4mps_checkerboard_floor': 5,}


# 3x2 grid of axes with a shared x-axis
fig, axs_unflat = plt.subplots(
    3, 2, facecolor=FACE_COLOR, figsize=FIG_SIZE, sharex=True,
    tight_layout=True
)

# flatten the grid to a 1D sequence; every axis from the third onward also
# gets a twin y-axis
axs = axs_unflat.flatten()
axs_twin = [ax.twinx() for ax in axs[2:]]

# accumulated across experiments -- presumably for a legend (TODO confirm)
wind_speed_handles = []
wind_speed_labels = []

for expt in session.query(models.Experiment):

    # label depends on the insect named in the experiment id
    # NOTE(review): if an id contains neither 'fly' nor 'mosquito',
    # wind_speed_label keeps its value from the previous iteration (or is
    # undefined on the first one) -- confirm all ids match
    if 'fly' in expt.id:

        wind_speed_label = 'fly {} m/s'.format(expt.wind_speed)

    elif 'mosquito' in expt.id:

        wind_speed_label = 'mosquito {} m/s'.format(expt.wind_speed)

    print(expt.id)

    # get threshold and crossing group for this experiment
    # (.first() yields None when no matching row exists)
    threshold = session.query(models.Threshold).filter_by(
        experiment=expt, determination=DETERMINATION
    ).first()
示例#49
0
"""
Plot the discriminability of two response sets when they are classified by thresholding.
"""
from __future__ import print_function, division
import numpy as np
import matplotlib.pyplot as plt
from db_api.connect import session
from db_api import models

# variable whose discrimination thresholds are queried below
RESPONSE_VAR = 'heading_xyz'

# plot appearance
FACE_COLOR = 'white'
AX_SIZE = (6, 3)  # size of one axis, used to scale the figure

expts = session.query(models.Experiment).all()

# one row of two axes per experiment
fig_size = (2 * AX_SIZE[0], len(expts) * AX_SIZE[1])
fig, axs = plt.subplots(len(expts),
                        2,
                        facecolor=FACE_COLOR,
                        figsize=fig_size,
                        tight_layout=True)

for expt, ax_row in zip(expts, axs):

    # crossing groups of this experiment recorded with odor on
    cgs = session.query(models.CrossingGroup).filter_by(experiment=expt,
                                                        odor_state='on')

    # zip pairs crossing groups with the two axes of the row, so at most two
    # crossing groups per experiment are used
    for cg, ax in zip(cgs, ax_row):
        # discrimination thresholds stored for this crossing group and variable
        disc_ths = session.query(models.DiscriminationThreshold).\
            filter_by(crossing_group=cg, variable=RESPONSE_VAR)