Exemple #1
0
    def test_nansem_gives_same_as_scipy_stats_sem(self):
        """
        Check stats.nansem against scipy.stats.sem evaluated on the NaN-free part of the data.
        """
        n_valid, n_missing, n_blank = 1000, 100, 5

        # vector, fully populated
        vec = np.random.normal(0, 1, n_valid)
        self.assertAlmostEqual(scipy.stats.sem(vec), stats.nansem(vec))

        # same vector with a tail of NaNs appended
        vec = np.concatenate([vec, np.full(n_missing, np.nan)])
        self.assertAlmostEqual(scipy.stats.sem(vec[:n_valid]), stats.nansem(vec))

        # matrix, fully populated: flattened, then along each axis
        mat = np.random.normal(0, 2, (50, 50))
        self.assertAlmostEqual(
            scipy.stats.sem(mat, axis=None), stats.nansem(mat, axis=None))
        for axis in (0, 1):
            np.testing.assert_array_almost_equal(
                scipy.stats.sem(mat, axis=axis), stats.nansem(mat, axis=axis))

        # matrix whose trailing rows are NaN: flattened, then down the columns
        holey = mat.copy()
        holey[-n_blank:, :] = np.nan
        intact = holey[:-n_blank, :]
        self.assertAlmostEqual(
            scipy.stats.sem(intact, axis=None), stats.nansem(holey, axis=None))
        np.testing.assert_array_almost_equal(
            scipy.stats.sem(intact, axis=0), stats.nansem(holey, axis=0))

        # matrix whose trailing columns are NaN: flattened, then along the rows
        holey = mat.copy()
        holey[:, -n_blank:] = np.nan
        intact = holey[:, :-n_blank]
        self.assertAlmostEqual(
            scipy.stats.sem(intact, axis=None), stats.nansem(holey, axis=None))
        np.testing.assert_array_almost_equal(
            scipy.stats.sem(intact, axis=1), stats.nansem(holey, axis=1))
def get_time_avg_response_diff_and_bounds(responses_below, responses_above):
    """
    Time-averaged difference between the mean "below threshold" response and the mean
    "above threshold" response (e.g., heading), with SEM-based bounds.

    Each group's mean time-course is first shifted so that its first entry is zero. The
    difference is taken as below minus above. The bounds come from shifting each mean by
    one standard error in the direction that shrinks (lower bound) or widens (upper bound)
    the difference. If either group holds fewer than two responses, both bounds are None.

    :param responses_below: 2D array of responses "below threshold", can have nans
    :param responses_above: 2D array of responses "above threshold", can have nans
    :return: difference, lower bound on difference, upper bound on difference
    """

    # mean time-course of each group
    avg_below = np.nanmean(responses_below, axis=0)
    avg_above = np.nanmean(responses_above, axis=0)

    # re-reference each time-course to its first sample
    avg_below = avg_below - avg_below[0]
    avg_above = avg_above - avg_above[0]

    # standard error of the mean at each time point
    err_below = stats.nansem(responses_below, axis=0)
    err_above = stats.nansem(responses_above, axis=0)

    # collapse the difference over time
    diff = np.nanmean(avg_below - avg_above)

    # a standard error needs at least two responses per group
    if len(responses_below) < 2 or len(responses_above) < 2:
        return diff, None, None

    lb = np.nanmean((avg_below - err_below) - (avg_above + err_above))
    ub = np.nanmean((avg_below + err_below) - (avg_above - err_above))

    return diff, lb, ub
Exemple #3
0
                                                nan_pad=True)

            if SUBTRACT_INITIAL_HEADING:

                headings -= headings[N_TIMESTEPS_BEFORE]

            crossings_time_series[crossing_ctr, :] = headings
            crossing_ctr += 1

        crossings_time_series = crossings_time_series[:crossing_ctr, :]

        print('{} total crossings'.format(len(crossings_time_series)))

        # get mean and sem
        crossings_mean = np.nanmean(crossings_time_series, axis=0)
        crossings_sem = stats.nansem(crossings_time_series, axis=0)
        crossings_std = np.nanstd(crossings_time_series, axis=0)

        # make plot
        handle, = axs[0, col_ctr].plot(time_vec,
                                       crossings_mean,
                                       lw=2,
                                       color=EXPT_COLORS[expt.id])
        axs[0, col_ctr].fill_between(time_vec,
                                     crossings_mean - crossings_std,
                                     crossings_mean + crossings_std,
                                     color=EXPT_COLORS[expt.id],
                                     alpha=0.3)

        if col_ctr == 0:
            wind_speed_handles.append(handle)
Exemple #4
0
        var_after = crossing.timepoint_field(session,
                                             VARIABLE,
                                             first=0,
                                             last=DISPLAY_END - 1,
                                             first_rel_to='peak',
                                             last_rel_to='peak')

        var_array[ctr,
                  -DISPLAY_START - len(var_before):-DISPLAY_START] = var_before
        var_array[ctr,
                  -DISPLAY_START:-DISPLAY_START + len(var_after)] = var_after

# calculate mean, std, and sem of these arrays
mean_below = np.nanmean(var_array_below, axis=0)
std_below = np.nanstd(var_array_below, axis=0)
sem_below = stats.nansem(var_array_below, axis=0)

mean_above = np.nanmean(var_array_above, axis=0)
std_above = np.nanstd(var_array_above, axis=0)
sem_above = stats.nansem(var_array_above, axis=0)

# zero means around their value at their peaks
mean_below -= mean_below[-DISPLAY_START]
mean_above -= mean_above[-DISPLAY_START]

# get integrals
integral_mean = mean_above[INTEGRAL_START - DISPLAY_START:INTEGRAL_END - DISPLAY_START] - \
    mean_below[INTEGRAL_START - DISPLAY_START:INTEGRAL_END - DISPLAY_START]

# plot things nicely
fig, ax = plt.subplots(1,
# Fill one row of the matching array per crossing with the variable's time-series around
# the crossing peak. var_array_below / var_array_above are written in place through the
# zip below.
# NOTE(review): the arrays are presumably preallocated (likely NaN-filled) earlier in the
# file -- confirm.
for crossings, var_array in zip([crossings_below, crossings_above],
                                [var_array_below, var_array_above]):
    for ctr, crossing in enumerate(crossings):

        # samples from DISPLAY_START up to one step before the peak
        var_before = crossing.timepoint_field(session, VARIABLE, first=DISPLAY_START, last=-1,
                                              first_rel_to='peak', last_rel_to='peak')
        # samples from the peak up to DISPLAY_END - 1 steps after it
        var_after = crossing.timepoint_field(session, VARIABLE, first=0, last=DISPLAY_END - 1,
                                             first_rel_to='peak', last_rel_to='peak')

        # column -DISPLAY_START holds the peak sample: pre-peak data is right-aligned
        # against it, post-peak data starts at it
        # NOTE(review): this indexing assumes DISPLAY_START is negative -- confirm
        var_array[ctr, -DISPLAY_START - len(var_before):-DISPLAY_START] = var_before
        var_array[ctr, -DISPLAY_START:-DISPLAY_START + len(var_after)] = var_after

# calculate mean, std, and sem of these arrays
# (column-wise over crossings, ignoring NaNs)
mean_below = np.nanmean(var_array_below, axis=0)
std_below = np.nanstd(var_array_below, axis=0)
sem_below = stats.nansem(var_array_below, axis=0)

mean_above = np.nanmean(var_array_above, axis=0)
std_above = np.nanstd(var_array_above, axis=0)
sem_above = stats.nansem(var_array_above, axis=0)

# zero means around their value at their peaks
mean_below -= mean_below[-DISPLAY_START]
mean_above -= mean_above[-DISPLAY_START]

# get integrals
# NOTE: despite the name, this is the pointwise (above - below) difference of the zeroed
# means over the [INTEGRAL_START, INTEGRAL_END) window; no summation happens here
integral_mean = mean_above[INTEGRAL_START - DISPLAY_START:INTEGRAL_END - DISPLAY_START] - \
    mean_below[INTEGRAL_START - DISPLAY_START:INTEGRAL_END - DISPLAY_START]

# plot things nicely
fig, ax = plt.subplots(1, 1, facecolor=FACE_COLOR, figsize=FIG_SIZE, tight_layout=True)
def _peak_aligned_window(headings, start, peak_time, end, ts_before, ts_after):
    """
    Cut a NaN-padded snippet of length ts_before + ts_after out of headings such that
    the sample at peak_time lands at index ts_before. Only samples in [start, end)
    are copied; anything outside that range stays NaN.
    """
    window = np.nan * np.zeros((ts_before + ts_after, ))

    n_pre = peak_time - start
    n_post = end - peak_time

    if n_pre >= ts_before:
        window[:ts_before] = headings[peak_time - ts_before:peak_time]
    else:
        # segment begins too close to the peak: right-align what exists
        window[ts_before - n_pre:ts_before] = headings[start:peak_time]

    if n_post >= ts_after:
        window[ts_before:] = headings[peak_time:peak_time + ts_after]
    else:
        # segment ends too soon after the peak: left-align what exists
        window[ts_before:ts_before + n_post] = headings[peak_time:end]

    return window


def crossing_triggered_headings_early_late_surge(
        SEED, N_TRAJS, DURATION, DT, BOUNDS, TAU, NOISE, BIAS, AGENT_THRESHOLD,
        SURGE_AMP, TAU_SURGE, PL_CONC, PL_MEAN, PL_STD, ANALYSIS_THRESHOLD,
        H_MIN_PEAK, H_MAX_PEAK, X_MIN_PEAK, X_MAX_PEAK, EARLY_LESS_THAN,
        SUBTRACT_PEAK_HEADING, T_BEFORE, T_AFTER, SAVE_FILE):
    """
    Fly several surging agents through a simulated plume and plot their
    plume-crossing-triggered headings, split into early and late crossings.

    Two files are written as a side effect: SAVE_FILE + '_full.npy' (every included
    crossing with its 1-based crossing number) and SAVE_FILE + '.npy' (time vector and
    the early/late mean headings).

    :param SEED: seed passed to np.random.seed before trajectories are generated
    :param N_TRAJS: number of trajectories to simulate
    :param DURATION: forwarded to the agent's track method
    :param DT: time step; forwarded to track and used to convert T_BEFORE/T_AFTER to steps
    :param BOUNDS: three (low, high) pairs; start positions are drawn uniformly from them,
        the first two pairs also set the 2D histogram range
    :param TAU, NOISE, BIAS, AGENT_THRESHOLD, SURGE_AMP, TAU_SURGE: forwarded to SurgingAgent
    :param PL_CONC, PL_MEAN, PL_STD: forwarded to GaussianLaminarPlume
    :param ANALYSIS_THRESHOLD: threshold passed to segment_by_threshold on each odor trace
    :param H_MIN_PEAK, H_MAX_PEAK: a crossing is kept only if its heading at the peak lies
        in [H_MIN_PEAK, H_MAX_PEAK)
    :param X_MIN_PEAK, X_MAX_PEAK: a crossing is kept only if its x-position at the peak
        lies in [X_MIN_PEAK, X_MAX_PEAK)
    :param EARLY_LESS_THAN: crossings whose 1-based number within their trajectory is
        below this value are "early", all others "late"
    :param SUBTRACT_PEAK_HEADING: if truthy, subtract the heading at the peak from each crossing
    :param T_BEFORE, T_AFTER: time to show before/after the peak
    :param SAVE_FILE: path prefix for the saved results
    :return: the matplotlib figure
    """

    # build plume and agent
    pl = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)

    ag = SurgingAgent(tau=TAU,
                      noise=NOISE,
                      bias=BIAS,
                      threshold=AGENT_THRESHOLD,
                      hit_trigger='peak',
                      surge_amp=SURGE_AMP,
                      tau_surge=TAU_SURGE,
                      bounds=BOUNDS)

    # GENERATE TRAJECTORIES
    np.random.seed(SEED)

    trajs = []

    for _ in range(N_TRAJS):

        # choose random start position
        start_pos = np.array([
            np.random.uniform(*BOUNDS[0]),
            np.random.uniform(*BOUNDS[1]),
            np.random.uniform(*BOUNDS[2]),
        ])

        # make trajectory
        traj = ag.track(plume=pl,
                        start_pos=start_pos,
                        duration=DURATION,
                        dt=DT)
        traj['headings'] = heading(traj['vs'])[:, 2]
        trajs.append(traj)

    # ANALYZE TRAJECTORIES
    n_crossings = []

    # collect early and late crossings
    crossings_early = []
    crossings_late = []

    crossings_save = []

    ts_before = int(T_BEFORE / DT)
    ts_after = int(T_AFTER / DT)

    for traj in trajs:

        starts, _onsets, peak_times, _offsets, ends = \
            segment_by_threshold(traj['odors'], ANALYSIS_THRESHOLD)[0].T

        n_crossings.append(len(peak_times))

        for ctr, (start, peak_time, end) in enumerate(zip(starts, peak_times, ends)):

            # skip crossings that don't meet inclusion criteria
            if not (H_MIN_PEAK <= traj['headings'][peak_time] < H_MAX_PEAK):
                continue
            if not (X_MIN_PEAK <= traj['xs'][peak_time, 0] < X_MAX_PEAK):
                continue

            crossing = _peak_aligned_window(
                traj['headings'], start, peak_time, end, ts_before, ts_after)

            if SUBTRACT_PEAK_HEADING:
                crossing -= crossing[ts_before]

            # ctr counts every crossing in the trajectory, including skipped ones
            crossing_number = ctr + 1

            if crossing_number < EARLY_LESS_THAN:
                crossings_early.append(crossing)
            else:
                crossings_late.append(crossing)

            crossings_save.append((crossing_number, crossing.copy()))

    # save crossings
    save_dict_full = {
        'ts_before': ts_before,
        'ts_after': ts_after,
        'crossings': crossings_save
    }
    np.save(SAVE_FILE + '_full.npy', np.array([save_dict_full]))

    n_crossings = np.array(n_crossings)

    crossings_early = np.array(crossings_early)
    crossings_late = np.array(crossings_late)

    t = np.arange(-ts_before, ts_after) * DT

    # coerce so the boolean masks below work even if a plain sequence is returned
    p_vals = np.asarray(get_ks_p_vals(crossings_early, crossings_late))

    h_mean_early = np.nanmean(crossings_early, axis=0)
    h_sem_early = nansem(crossings_early, axis=0)

    h_mean_late = np.nanmean(crossings_late, axis=0)
    h_sem_late = nansem(crossings_late, axis=0)

    save_data = {'t': t, 'early': h_mean_early, 'late': h_mean_late}
    np.save(SAVE_FILE + '.npy', np.array([save_data]))

    fig = plt.figure(figsize=(15, 15), tight_layout=True)
    axs = [fig.add_subplot(3, 2, 1), fig.add_subplot(3, 2, 2)]

    handles = []

    groups = [
        ('early', h_mean_early, h_sem_early, 'b'),
        ('late', h_mean_late, h_sem_late, 'g'),
    ]

    for label, h_mean, h_sem, color in groups:
        # plotting is best-effort (e.g. a group without any crossings cannot be drawn),
        # but the failure is reported and interpreter-level exceptions such as
        # KeyboardInterrupt are no longer swallowed
        try:
            handles.append(axs[0].plot(t,
                                       h_mean,
                                       lw=3,
                                       color=color,
                                       label=label)[0])
            axs[0].fill_between(t,
                                h_mean - h_sem,
                                h_mean + h_sem,
                                color=color,
                                alpha=0.2)
        except Exception as err:
            print('Could not plot {} crossings: {}'.format(label, err))

    ## get y-position to plot p-vals at
    y_min, y_max = axs[0].get_ylim()
    y_range = y_max - y_min

    y_p_vals = (y_min + 0.02 * y_range) * np.ones(len(p_vals))

    # one marker line per significance level; time points above the level are blanked
    for level, color in [(0.1, 'gray'), (0.05, (1, 0, 0)), (0.01, (.25, 0, 0))]:
        y_level = y_p_vals.copy()
        y_level[p_vals > level] = np.nan
        axs[0].plot(t, y_level, lw=4, color=color)

    axs[0].set_xlabel('time since peak (s)')

    if SUBTRACT_PEAK_HEADING:
        axs[0].set_ylabel('change in heading (deg)')
    else:
        axs[0].set_ylabel('heading (deg)')

    axs[0].legend(handles=handles, fontsize=16)

    # unit-width bins centered on the integer crossing counts 0 .. max
    bin_min = -0.5
    bin_max = n_crossings.max() + 0.5

    # np.linspace requires an integer number of samples
    n_bin_edges = int(n_crossings.max()) + 2
    bins = np.linspace(bin_min, bin_max, n_bin_edges, endpoint=True)

    # `density` is the replacement for the removed `normed` keyword
    axs[1].hist(n_crossings, bins=bins, lw=0, density=True)
    axs[1].set_xlim(bin_min, bin_max)

    axs[1].set_xlabel('number of crossings')
    axs[1].set_ylabel('proportion of trajectories')

    # example trajectory
    axs.append(fig.add_subplot(3, 1, 2))

    axs[2].plot(trajs[0]['xs'][:, 0], trajs[0]['xs'][:, 1])
    axs[2].axhline(0, color='gray', ls='--')

    axs[2].set_xlabel('x (m)')
    axs[2].set_ylabel('y (m)')

    # xy occupancy histogram over (at most) the first 3000 trajectories
    axs.append(fig.add_subplot(3, 1, 3))

    all_xy = np.concatenate([traj['xs'][:, :2] for traj in trajs[:3000]],
                            axis=0)
    x_bins = np.linspace(BOUNDS[0][0], BOUNDS[0][1], 66, endpoint=True)
    y_bins = np.linspace(BOUNDS[1][0], BOUNDS[1][1], 30, endpoint=True)

    axs[3].hist2d(all_xy[:, 0], all_xy[:, 1], bins=(x_bins, y_bins))

    axs[3].set_xlabel('x (m)')
    axs[3].set_ylabel('y (m)')

    for ax in axs:
        set_font_size(ax, 20)

    return fig
def crossing_triggered_headings_early_late_vary_param(
        SEED, SAVE_FILE, N_TRAJS, DURATION, DT, TAU, NOISE, BIAS,
        HIT_INFLUENCE, SQRT_K_0, VARIABLE_PARAMS, BOUNDS, PL_CONC, PL_MEAN,
        PL_STD, H_MIN_PEAK, H_MAX_PEAK, X_MIN_PEAK, X_MAX_PEAK,
        EARLY_LESS_THAN, SUBTRACT_PEAK_HEADING, T_BEFORE, T_AFTER, T_INT_START,
        T_INT_END, AX_GRID):
    """
    Fly several centerline-inferring agents through a simulated plume for a range of
    parameter settings and plot the late-minus-early difference in
    plume-crossing-triggered heading for each setting.

    If SAVE_FILE exists the stored results are loaded instead of re-running the
    simulations; otherwise the simulations are run and the results saved to SAVE_FILE.

    :param SEED: seed passed to np.random.seed before the simulations
    :param SAVE_FILE: path of the results cache. NOTE: np.save appends '.npy' to names
        lacking it, so the path should end in '.npy' for the cache to be found again.
        The file is loaded with pickling enabled, so it must come from a trusted source.
    :param N_TRAJS: number of trajectories per parameter setting
    :param DURATION, DT: forwarded to the agent's track method; DT also converts times to steps
    :param TAU, NOISE, BIAS, HIT_INFLUENCE: forwarded to CenterlineInferringAgent
    :param SQRT_K_0: k_0 is built as a 2x2 diagonal matrix with SQRT_K_0 ** 2 on the diagonal
    :param VARIABLE_PARAMS: list of dicts with exactly the keys 'threshold', 'tau_memory',
        'sqrt_k_s'; in each dict one value is a list (the varying parameter) and the
        others are held fixed. The dicts are not modified.
    :param BOUNDS: three (low, high) pairs used to draw start positions; forwarded to the agent
    :param PL_CONC, PL_MEAN, PL_STD: forwarded to GaussianLaminarPlume
    :param H_MIN_PEAK, H_MAX_PEAK: a crossing is kept only if its heading at the peak lies
        in [H_MIN_PEAK, H_MAX_PEAK)
    :param X_MIN_PEAK, X_MAX_PEAK: a crossing is kept only if its x-position at the peak
        lies in [X_MIN_PEAK, X_MAX_PEAK)
    :param EARLY_LESS_THAN: crossings whose 0-based index within their trajectory is below
        this value are "early", all others "late"
    :param SUBTRACT_PEAK_HEADING: if truthy, subtract the heading at the peak from each crossing
    :param T_BEFORE, T_AFTER: time to keep before/after the peak
    :param T_INT_START, T_INT_END: the heading difference is averaged over
        T_INT_START < t <= T_INT_END
    :param AX_GRID: (n_rows, n_cols) of the axis grid
    :return: the matplotlib figure
    :raises ValueError: if a dict in VARIABLE_PARAMS has no list-valued entry
    """

    # try to open saved results

    if os.path.isfile(SAVE_FILE):

        print('Results file found. Loading results file.')
        # the results are a dict wrapped in an object array, which NumPy only
        # unpickles when explicitly allowed to
        results = np.load(SAVE_FILE, allow_pickle=True)

    else:

        print('Results file not found. Running analysis...')
        np.random.seed(SEED)

        # build plume

        pl = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)

        # time axis and averaging window are the same for every parameter set

        ts_before = int(T_BEFORE / DT)
        ts_after = int(T_AFTER / DT)

        t = np.arange(-ts_before, ts_after) * DT
        in_window = (t > T_INT_START) & (t <= T_INT_END)

        # loop over all parameter sets

        varying_params = []
        fixed_params = []

        early_late_heading_diffs_all = []
        early_late_heading_diffs_lb_all = []
        early_late_heading_diffs_ub_all = []

        for variable_params in VARIABLE_PARAMS:

            print('Variable params: {}'.format(variable_params))

            assert set(variable_params.keys()) == set(
                ['threshold', 'tau_memory', 'sqrt_k_s'])

            # identify which parameter is varying (first list-valued entry)

            varying_key = next(
                (k for k, v in variable_params.items() if isinstance(v, list)),
                None)

            if varying_key is None:

                raise ValueError(
                    'Exactly one entry of each VARIABLE_PARAMS dict must be a '
                    'list, but none was found in {}'.format(variable_params))

            varying_vals = variable_params[varying_key]
            n_param_sets = len(varying_vals)

            varying_params.append((varying_key, varying_vals))
            fixed_params.append([(k, v)
                                 for k, v in variable_params.items()
                                 if k != varying_key])

            # make other parameters into lists so they can all be looped over nicely;
            # a new dict is built so the caller's VARIABLE_PARAMS stays untouched
            # (rewriting it in place would make every value a list and break the
            # detection of the varying parameter on a second call)

            param_lists = {
                k: (v if isinstance(v, list) else [v] * n_param_sets)
                for k, v in variable_params.items()
            }

            early_late_heading_diffs = []
            early_late_heading_diffs_lb = []
            early_late_heading_diffs_ub = []

            for param_set_ctr in range(n_param_sets):

                threshold = param_lists['threshold'][param_set_ctr]
                tau_memory = param_lists['tau_memory'][param_set_ctr]
                sqrt_k_s = param_lists['sqrt_k_s'][param_set_ctr]

                k_0 = np.array([
                    [SQRT_K_0**2, 0],
                    [0, SQRT_K_0**2],
                ])
                k_s = np.array([
                    [sqrt_k_s**2, 0],
                    [0, sqrt_k_s**2],
                ])

                # build tracking agent

                ag = CenterlineInferringAgent(tau=TAU,
                                              noise=NOISE,
                                              bias=BIAS,
                                              threshold=threshold,
                                              hit_trigger='peak',
                                              hit_influence=HIT_INFLUENCE,
                                              tau_memory=tau_memory,
                                              k_0=k_0,
                                              k_s=k_s,
                                              bounds=BOUNDS)

                trajs = []

                for _ in range(N_TRAJS):

                    # choose random start position

                    start_pos = np.array([
                        np.random.uniform(*BOUNDS[0]),
                        np.random.uniform(*BOUNDS[1]),
                        np.random.uniform(*BOUNDS[2]),
                    ])

                    # make trajectory

                    traj = ag.track(plume=pl,
                                    start_pos=start_pos,
                                    duration=DURATION,
                                    dt=DT)

                    traj['headings'] = heading(traj['vs'])[:, 2]

                    trajs.append(traj)

                crossings_early = []
                crossings_late = []

                for traj in trajs:

                    starts, _onsets, peak_times, _offsets, ends = \
                        segment_by_threshold(traj['odors'], threshold)[0].T

                    for ctr, (start, peak_time,
                              end) in enumerate(zip(starts, peak_times, ends)):

                        # skip crossings that don't meet inclusion criteria

                        if not (H_MIN_PEAK <= traj['headings'][peak_time] <
                                H_MAX_PEAK):

                            continue

                        if not (X_MIN_PEAK <= traj['xs'][peak_time, 0] <
                                X_MAX_PEAK):

                            continue

                        # NaN-padded snippet with the peak at index ts_before

                        crossing = np.nan * np.zeros((ts_before + ts_after, ))

                        ts_before_crossing = peak_time - start
                        ts_after_crossing = end - peak_time

                        if ts_before_crossing >= ts_before:

                            crossing[:ts_before] = traj['headings'][
                                peak_time - ts_before:peak_time]

                        else:

                            crossing[ts_before - ts_before_crossing:ts_before] = \
                                traj['headings'][start:peak_time]

                        if ts_after_crossing >= ts_after:

                            crossing[ts_before:] = traj['headings'][
                                peak_time:peak_time + ts_after]

                        else:

                            crossing[ts_before:ts_before + ts_after_crossing] = \
                                traj['headings'][peak_time:end]

                        if SUBTRACT_PEAK_HEADING:

                            crossing -= crossing[ts_before]

                        # NOTE(review): ctr is the 0-based index of the crossing within
                        # its trajectory, so the first EARLY_LESS_THAN crossings count
                        # as early -- confirm this is the intended convention

                        if ctr < EARLY_LESS_THAN:

                            crossings_early.append(crossing)

                        else:

                            crossings_late.append(crossing)

                crossings_early = np.array(crossings_early)
                crossings_late = np.array(crossings_late)

                h_mean_early = np.nanmean(crossings_early, axis=0)
                h_mean_late = np.nanmean(crossings_late, axis=0)

                h_sem_early = nansem(crossings_early, axis=0)
                h_sem_late = nansem(crossings_late, axis=0)

                h_mean_diff = h_mean_late - h_mean_early

                # bounds: move each mean by one sem towards / away from the other

                h_mean_diff_lb = h_mean_late - h_sem_late - (h_mean_early +
                                                             h_sem_early)
                h_mean_diff_ub = h_mean_late + h_sem_late - (h_mean_early -
                                                             h_sem_early)

                early_late_heading_diffs.append(h_mean_diff[in_window].mean())
                early_late_heading_diffs_lb.append(
                    h_mean_diff_lb[in_window].mean())
                early_late_heading_diffs_ub.append(
                    h_mean_diff_ub[in_window].mean())

            early_late_heading_diffs_all.append(
                np.array(early_late_heading_diffs))
            early_late_heading_diffs_lb_all.append(
                np.array(early_late_heading_diffs_lb))
            early_late_heading_diffs_ub_all.append(
                np.array(early_late_heading_diffs_ub))

        # save results

        results = np.array([{
            'varying_params':
            varying_params,
            'fixed_params':
            fixed_params,
            'early_late_heading_diffs_all':
            early_late_heading_diffs_all,
            'early_late_heading_diffs_lb_all':
            early_late_heading_diffs_lb_all,
            'early_late_heading_diffs_ub_all':
            early_late_heading_diffs_ub_all,
        }])
        np.save(SAVE_FILE, results)

    # unwrap the dict from its one-element object array

    results = results[0]

    ## MAKE PLOTS

    fig_size = (5 * AX_GRID[1], 4 * AX_GRID[0])
    # squeeze=False keeps axs a 2D array even for a 1x1 or single-row grid
    fig, axs = plt.subplots(*AX_GRID,
                            figsize=fig_size,
                            tight_layout=True,
                            squeeze=False)
    axs_flat = axs.flatten()

    x_labels = {
        'tau_memory': 'tau_m (s)',
        'threshold': 'threshold (% ethanol)',
    }

    for ax_ctr, (xs_name, x_ticklabels) in enumerate(results['varying_params']):

        ax = axs_flat[ax_ctr]

        ys_plot = results['early_late_heading_diffs_all'][ax_ctr]
        ys_lb = results['early_late_heading_diffs_lb_all'][ax_ctr]
        ys_ub = results['early_late_heading_diffs_ub_all'][ax_ctr]

        ys_err = [ys_plot - ys_lb, ys_ub - ys_plot]
        xs_plot = np.arange(len(ys_plot))

        ax.errorbar(xs_plot, ys_plot, yerr=ys_err, color='k', fmt='--o')

        ax.axhline(0, color='gray')

        # y-limits: cover all error bars and zero, plus a 10% margin

        lb_min = np.min(ys_lb)
        ub_max = np.max(ys_ub)

        if ub_max > 0:

            y_range = ub_max - lb_min

        else:

            y_range = -lb_min

        y_min = lb_min - 0.1 * y_range
        y_max = max(ub_max, 0) + 0.1 * y_range

        ax.set_xlim(-1, len(ys_plot))
        ax.set_xticks(xs_plot)

        if xs_name == 'threshold':

            # NOTE(review): 0.0476 / 526 presumably converts the raw threshold into
            # % ethanol (see the axis label) -- confirm the origin of the constants
            x_ticklabels = [
                '{0:.4f}'.format(xtl * (0.0476 / 526)) for xtl in x_ticklabels
            ]

        ax.set_xticklabels(x_ticklabels)

        ax.set_ylim(y_min, y_max)

        ax.set_xlabel(x_labels.get(xs_name, xs_name))
        ax.set_ylabel('mean heading difference\nfor late vs. early crossings')

    for ax in axs_flat:

        set_font_size(ax, 16)

    return fig
Exemple #8
0
def early_vs_late_heading_timecourse_x0_accounted_for(
        CROSSING_GROUP_IDS, CROSSING_GROUP_LABELS,
        X_0_MIN, X_0_MAX, H_0_MIN, H_0_MAX, CROSSING_NUMBER_MAX,
        MAX_CROSSINGS_EARLY, SUBTRACT_INITIAL_HEADING,
        T_BEFORE, T_AFTER, SCATTER_INTEGRATION_WINDOW,
        AX_SIZE, AX_GRID, EARLY_LATE_COLORS, ALPHA,
        P_VAL_COLOR, P_VAL_Y_LIM, LEGEND_CROSSING_GROUP_ID,
        FONT_SIZE):
    """
    Show early vs. late headings for different experiments, along with a plot of the
    p-values for the difference between the two means.

    Before comparing, the linear dependence of heading on x_0 (x-position at the crossing
    peak) is regressed out separately at every time step, using early and late crossings
    pooled; the plotted time-courses and the KS tests use the residuals.

    A second figure (scatter of x_0 vs. window-averaged heading, colored by crossing
    number, titled with a partial correlation) is also built, but only the first figure
    is returned.

    NOTE: `session` and `DT` are not parameters; they are read from the enclosing
    (module) scope.

    :param CROSSING_GROUP_IDS: ids of the crossing groups to analyze, one axis per group
    :param CROSSING_GROUP_LABELS: mapping from crossing group id to axis title
    :param X_0_MIN, X_0_MAX: inclusive range of x-position at peak for a crossing to be used
    :param H_0_MIN, H_0_MAX: inclusive range of heading at peak for a crossing to be used
    :param CROSSING_NUMBER_MAX: largest crossing number included in the "late" group
    :param MAX_CROSSINGS_EARLY: largest crossing number included in the "early" group
    :param SUBTRACT_INITIAL_HEADING: if truthy, subtract the heading at the peak from each
        heading time-series
    :param T_BEFORE, T_AFTER: time before/after the peak to include
    :param SCATTER_INTEGRATION_WINDOW: (start, end) times relative to the peak over which
        headings are averaged for the scatter plot
    :param AX_SIZE: (width, height) of a single axis
    :param AX_GRID: (n_rows, n_cols) of the axis grid
    :param EARLY_LATE_COLORS: mapping from 'early'/'late' to plot color
    :param ALPHA: transparency of the sem bands
    :param P_VAL_COLOR: color of the p-value trace
    :param P_VAL_Y_LIM: y-limits of the p-value axis
    :param LEGEND_CROSSING_GROUP_ID: id of the crossing group whose axis gets a legend
    :param FONT_SIZE: font size applied to the first figure's axes
    :return: the first figure (residual heading time-courses with p-values)
    """

    # convert times to time steps
    ts_before = int(round(T_BEFORE / DT))
    ts_after = int(round(T_AFTER / DT))
    # window edges as indices into a heading time-series whose peak sits at index ts_before
    scatter_ts = [ts_before + int(round(t / DT)) for t in SCATTER_INTEGRATION_WINDOW]

    # loop over crossing groups
    # (all dicts below are keyed by crossing group id)
    x_0s_dict = {}
    headings_dict = {}
    residuals_dict = {}
    p_vals_dict = {}

    scatter_ys_dict = {}
    crossing_ns_dict = {}

    for cg_id in CROSSING_GROUP_IDS:

        # get crossing group
        crossing_group = session.query(models.CrossingGroup).filter_by(id=cg_id).first()

        # get early and late crossings
        crossings_dict = {}
        crossings_all = session.query(models.Crossing).filter_by(
            crossing_group=crossing_group)
        crossings_dict['early'] = crossings_all.filter(
            models.Crossing.crossing_number <= MAX_CROSSINGS_EARLY)
        crossings_dict['late'] = crossings_all.filter(
            models.Crossing.crossing_number > MAX_CROSSINGS_EARLY,
            models.Crossing.crossing_number <= CROSSING_NUMBER_MAX)

        x_0s_dict[cg_id] = {}
        headings_dict[cg_id] = {}

        scatter_ys_dict[cg_id] = {}
        crossing_ns_dict[cg_id] = {}

        for label in ['early', 'late']:

            x_0s = []
            headings = []
            scatter_ys = []
            crossing_ns = []

            # get all initial headings, initial xs, peak concentrations, and heading time-series
            for crossing in crossings_dict[label]:

                # sanity checks on the query filters above
                assert crossing.crossing_number > 0
                if label == 'early':
                    assert 0 < crossing.crossing_number <= MAX_CROSSINGS_EARLY
                elif label == 'late':
                    assert MAX_CROSSINGS_EARLY < crossing.crossing_number
                    assert crossing.crossing_number <= CROSSING_NUMBER_MAX

                # throw away crossings that do not meet trigger criteria
                x_0 = getattr(
                    crossing.feature_set_basic, 'position_x_{}'.format('peak'))
                h_0 = getattr(
                    crossing.feature_set_basic, 'heading_xyz_{}'.format('peak'))

                if not (X_0_MIN <= x_0 <= X_0_MAX): continue
                if not (H_0_MIN <= h_0 <= H_0_MAX): continue

                # store x_0 (uw/dw position)
                x_0s.append(x_0)

                # get and store headings
                # (from ts_before steps before the peak to ts_after - 1 steps after it;
                # NOTE(review): nan_pad=True presumably pads missing samples with NaN so
                # every series has length ts_before + ts_after -- confirm)
                temp = crossing.timepoint_field(
                    session, 'heading_xyz', -ts_before, ts_after - 1,
                    'peak', 'peak', nan_pad=True)

                # subtract initial heading if desired
                # (index ts_before is the sample at the peak)
                if SUBTRACT_INITIAL_HEADING: temp -= temp[ts_before]

                # store headings
                headings.append(temp)

                # calculate mean heading over integration window for scatter plot
                scatter_ys.append(np.nanmean(temp[scatter_ts[0]:scatter_ts[1]]))
                crossing_ns.append(crossing.crossing_number)

            x_0s_dict[cg_id][label] = np.array(x_0s).copy()
            headings_dict[cg_id][label] = np.array(headings).copy()

            scatter_ys_dict[cg_id][label] = np.array(scatter_ys).copy()
            crossing_ns_dict[cg_id][label] = np.array(crossing_ns).copy()

        x_early = x_0s_dict[cg_id]['early']
        x_late = x_0s_dict[cg_id]['late']
        h_early = headings_dict[cg_id]['early']
        h_late = headings_dict[cg_id]['late']

        # pool early and late crossings (early first) for the regression
        x0s_all = np.concatenate([x_early, x_late])
        hs_all = np.concatenate([h_early, h_late], axis=0)

        residuals_dict[cg_id] = {
            'early': np.nan * np.zeros(h_early.shape),
            'late': np.nan * np.zeros(h_late.shape),
        }

        # fit heading linear prediction from x0 at each time point
        # and subtract from original heading
        for t_step in range(ts_before + ts_after):

            # get all headings for this time point
            hs_t = hs_all[:, t_step]
            residuals = np.nan * np.zeros(hs_t.shape)

            # only use headings that exist
            not_nan = ~np.isnan(hs_t)

            # fit linear model
            rgr = linear_model.LinearRegression()
            rgr.fit(x0s_all[not_nan][:, None], hs_t[not_nan])

            residuals[not_nan] = hs_t[not_nan] - rgr.predict(x0s_all[not_nan][:, None])

            # residuals must be missing exactly where the headings were missing
            assert np.all(np.isnan(residuals) == np.isnan(hs_t))

            # undo the pooling: the first len(x_early) entries are the early crossings
            r_early, r_late = np.split(residuals, [len(x_early)])
            residuals_dict[cg_id]['early'][:, t_step] = r_early
            residuals_dict[cg_id]['late'][:, t_step] = r_late

        # loop through all time points and calculate p-value (ks-test)
        # between early and late
        p_vals = []

        for t_step in range(ts_before + ts_after):

            early_with_nans = residuals_dict[cg_id]['early'][:, t_step]
            late_with_nans = residuals_dict[cg_id]['late'][:, t_step]

            early_no_nans = early_with_nans[~np.isnan(early_with_nans)]
            late_no_nans = late_with_nans[~np.isnan(late_with_nans)]

            # calculate statistical significance
            # (element [1] of the ks_2samp result is the p-value)
            p_vals.append(ks_2samp(early_no_nans, late_no_nans)[1])

        p_vals_dict[cg_id] = p_vals


    ## MAKE PLOTS
    t = np.arange(-ts_before, ts_after) * DT

    # history-dependence
    fig_size = (AX_SIZE[0] * AX_GRID[1], AX_SIZE[1] * AX_GRID[0])
    fig_0, axs_0 = plt.subplots(*AX_GRID, figsize=fig_size, tight_layout=True)

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_0.flat):

        # get mean and sem of headings for early and late groups
        # (computed on the x_0-regressed residuals, not the raw headings)
        handles = []

        for label, color in EARLY_LATE_COLORS.items():
            headings_mean = np.nanmean(residuals_dict[cg_id][label], axis=0)
            headings_sem = stats.nansem(residuals_dict[cg_id][label], axis=0)

            handles.append(ax.plot(
                t, headings_mean, color=color, lw=2, label=label, zorder=1)[0])
            ax.fill_between(
                t, headings_mean - headings_sem, headings_mean + headings_sem,
                color=color, alpha=ALPHA, zorder=1)

        ax.set_xlabel('time since crossing (s)')

        if SUBTRACT_INITIAL_HEADING: ax.set_ylabel('heading* (deg.)')
        else: ax.set_ylabel('heading (deg.)')
        ax.set_title(CROSSING_GROUP_LABELS[cg_id])

        if cg_id == LEGEND_CROSSING_GROUP_ID:
            ax.legend(handles=handles, loc='upper right')
        set_fontsize(ax, FONT_SIZE)

        # plot p-value
        # (on a second y-axis sharing the time axis, with a reference line at 0.05)
        ax_twin = ax.twinx()

        ax_twin.plot(t, p_vals_dict[cg_id], color=P_VAL_COLOR, lw=2, ls='--', zorder=0)
        ax_twin.axhline(0.05, ls='-', lw=2, color='gray')

        ax_twin.set_ylim(*P_VAL_Y_LIM)
        ax_twin.set_ylabel('p-value (KS test)', fontsize=FONT_SIZE)

        set_fontsize(ax_twin, FONT_SIZE)

    # second figure: x_0 vs. window-averaged heading, one color per crossing number
    fig_1, axs_1 = plt.subplots(*AX_GRID, figsize=fig_size, tight_layout=True)
    cc = np.concatenate
    colors = get_n_colors(CROSSING_NUMBER_MAX, colormap='jet')

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_1.flat):

        # make scatter plot of x0s vs integrated headings vs crossing number
        x_0s_all = cc([x_0s_dict[cg_id]['early'], x_0s_dict[cg_id]['late']])
        ys_all = cc([scatter_ys_dict[cg_id]['early'], scatter_ys_dict[cg_id]['late']])
        cs_all = cc([crossing_ns_dict[cg_id]['early'], crossing_ns_dict[cg_id]['late']])

        # crossing numbers start at 1, hence the c-1 color index
        cs = np.array([colors[c-1] for c in cs_all])

        hs = []

        # one scatter call per crossing number so each gets its own legend entry
        for c in sorted(np.unique(cs_all)):
            label = 'cn = {}'.format(c)
            mask = cs_all == c
            h = ax.scatter(x_0s_all[mask], ys_all[mask],
                s=20, c=cs[mask], lw=0, label=label)
            hs.append(h)

        # calculate partial correlation between crossing number and heading given x
        not_nan = ~np.isnan(ys_all)
        r, p = stats.partial_corr(
            cs_all[not_nan], ys_all[not_nan], controls=[x_0s_all[not_nan]])

        ax.set_xlabel('x')
        ax.set_ylabel(r'$\Delta$h_mean({}:{}) (deg)'.format(
            *SCATTER_INTEGRATION_WINDOW))

        title = CROSSING_GROUP_LABELS[cg_id] + \
            ', R = {0:.2f}, P = {1:.3f}'.format(r, p)
        ax.set_title(title)

        ax.legend(handles=hs, loc='upper center', ncol=3)
        # NOTE(review): font size is hard-coded to 16 here rather than FONT_SIZE
        set_fontsize(ax, 16)

    # NOTE: fig_1 is created above but not returned
    return fig_0
def crossing_triggered_headings_all(
        SEED, N_TRAJS, DURATION, DT, TAU, NOISE, BIAS, AGENT_THRESHOLD,
        HIT_INFLUENCE, TAU_MEMORY, K_0, K_S, BOUNDS, PL_CONC, PL_MEAN, PL_STD,
        ANALYSIS_THRESHOLD, H_MIN_PEAK, H_MAX_PEAK, SUBTRACT_PEAK_HEADING,
        T_BEFORE, T_AFTER, Y_LIM):
    """
    Simulate agents tracking a Gaussian laminar plume and plot their headings aligned to
    the peak of every plume crossing.

    N_TRAJS trajectories are generated from start positions drawn uniformly within BOUNDS
    (numpy's RNG is seeded with SEED first). Crossings are obtained by segmenting each
    trajectory's odor signal at ANALYSIS_THRESHOLD; only crossings whose heading at the peak
    lies in [H_MIN_PEAK, H_MAX_PEAK) are kept. Each kept crossing contributes a nan-padded
    heading window of int(T_BEFORE / DT) time steps before and int(T_AFTER / DT) time steps
    from the peak onwards.

    :param SUBTRACT_PEAK_HEADING: if truthy, the heading at the peak is subtracted from
        each window
    :param Y_LIM: (lower, upper) y-axis limits of the plot
    :return: figure showing the individual crossings, their mean, and a +/- SEM band
    """

    # build plume and agent (gain matrices are scaled 2x2 identities)
    plume = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)
    agent = CenterlineInferringAgent(
        tau=TAU, noise=NOISE, bias=BIAS, threshold=AGENT_THRESHOLD,
        hit_trigger='peak', hit_influence=HIT_INFLUENCE,
        k_0=K_0 * np.eye(2), k_s=K_S * np.eye(2),
        tau_memory=TAU_MEMORY, bounds=BOUNDS)

    # generate trajectories from random start positions
    np.random.seed(SEED)

    trajs = []
    for _ in range(N_TRAJS):
        start_pos = np.array([np.random.uniform(*BOUNDS[dim]) for dim in range(3)])
        traj = agent.track(plume=plume, start_pos=start_pos, duration=DURATION, dt=DT)
        traj['headings'] = heading(traj['vs'])[:, 2]
        trajs.append(traj)

    # collect peak-aligned heading windows
    ts_before = int(T_BEFORE / DT)
    ts_after = int(T_AFTER / DT)
    window_len = ts_before + ts_after

    crossings = []
    for traj in trajs:
        starts, _onsets, peak_times, _offsets, ends = \
            segment_by_threshold(traj['odors'], ANALYSIS_THRESHOLD)[0].T
        hs = traj['headings']

        for start, peak_time, end in zip(starts, peak_times, ends):
            peak_heading_ok = H_MIN_PEAK <= hs[peak_time] < H_MAX_PEAK
            if not peak_heading_ok:
                continue

            # number of samples actually available on either side of the peak,
            # capped at the window size; the remainder of the window stays nan
            n_pre = min(ts_before, peak_time - start)
            n_post = min(ts_after, end - peak_time)

            crossing = np.full(window_len, np.nan)
            crossing[ts_before - n_pre:ts_before] = hs[peak_time - n_pre:peak_time]
            crossing[ts_before:ts_before + n_post] = hs[peak_time:peak_time + n_post]

            if SUBTRACT_PEAK_HEADING:
                crossing -= crossing[ts_before]

            crossings.append(crossing)

    crossings = np.array(crossings)

    # plot individual crossings, mean, and sem
    t = np.arange(-ts_before, ts_after) * DT

    fig, ax = plt.subplots(1, 1, figsize=(8, 6), tight_layout=True)

    h_mean = np.nanmean(crossings, axis=0)
    h_sem = nansem(crossings, axis=0)

    ax.plot(t, crossings.T, lw=0.5, alpha=0.5, color='c', zorder=0)
    ax.plot(t, h_mean, lw=3, color='k')
    ax.fill_between(t, h_mean - h_sem, h_mean + h_sem, color='k', alpha=0.2)
    ax.axvline(0, ls='--', color='gray')

    ax.set_ylim(*Y_LIM)
    ax.set_xlabel('time since peak (s)')
    ax.set_ylabel('change in heading (deg)' if SUBTRACT_PEAK_HEADING else 'heading (deg)')

    set_font_size(ax, 16)

    return fig
def crossing_triggered_headings_early_late_vary_param(
        SEED, SAVE_FILE, N_TRAJS, DURATION, DT,
        TAU, NOISE, BIAS, HIT_INFLUENCE, SQRT_K_0,
        VARIABLE_PARAMS, BOUNDS,
        PL_CONC, PL_MEAN, PL_STD,
        H_MIN_PEAK, H_MAX_PEAK,
        X_MIN_PEAK, X_MAX_PEAK,
        EARLY_LESS_THAN,
        SUBTRACT_PEAK_HEADING, T_BEFORE, T_AFTER,
        T_INT_START, T_INT_END, AX_GRID):
    """
    Run the early-vs-late crossing-triggered heading analysis over several parameter sweeps
    and plot, for every swept value, the time-averaged late-minus-early heading difference.

    If SAVE_FILE exists the results are loaded from it; otherwise the simulations are run
    (numpy's RNG is seeded with SEED) and the results are saved to SAVE_FILE.

    A crossing is "early" when its index within its trajectory is below EARLY_LESS_THAN and
    "late" otherwise. Only crossings whose peak heading lies in [H_MIN_PEAK, H_MAX_PEAK) and
    whose peak x-position lies in [X_MIN_PEAK, X_MAX_PEAK) are used. The heading difference
    is averaged over the times t with T_INT_START < t <= T_INT_END; the lower/upper bounds
    are obtained by shifting both group means by their SEMs.

    :param VARIABLE_PARAMS: iterable of dicts with exactly the keys 'threshold',
        'tau_memory' and 'sqrt_k_s'; in each dict the first list-valued entry is the swept
        parameter and non-list entries are repeated for every swept value. The dicts are
        not modified.
    :param AX_GRID: (n_rows, n_cols) of the subplot grid; one axis is used per sweep
    :return: figure with one errorbar plot per sweep
    :raises ValueError: if a dict in VARIABLE_PARAMS has no list-valued entry
    """

    # try to open saved results

    if os.path.isfile(SAVE_FILE):

        print('Results file found. Loading results file.')
        # The results are stored as an object array wrapping a dict, which numpy only
        # loads when pickling is explicitly allowed. Unpickling can execute arbitrary
        # code, so SAVE_FILE must come from a trusted source.
        results = np.load(SAVE_FILE, allow_pickle=True)

    else:

        print('Results file not found. Running analysis...')
        np.random.seed(SEED)

        # build plume

        pl = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)

        # window geometry is the same for every parameter set

        ts_before = int(T_BEFORE / DT)
        ts_after = int(T_AFTER / DT)
        t = np.arange(-ts_before, ts_after) * DT
        in_int_window = (t > T_INT_START) & (t <= T_INT_END)

        # loop over all parameter sweeps

        varying_params = []
        fixed_params = []

        early_late_heading_diffs_all = []
        early_late_heading_diffs_lb_all = []
        early_late_heading_diffs_ub_all = []

        for variable_params in VARIABLE_PARAMS:

            print('Variable params: {}'.format(variable_params))

            assert set(variable_params.keys()) == set(
                ['threshold', 'tau_memory', 'sqrt_k_s'])

            # identify which parameter is varying (the first list-valued entry)

            varying_key = None

            for key, vals in variable_params.items():
                if isinstance(vals, list):
                    varying_key = key
                    break

            if varying_key is None:
                raise ValueError(
                    'no list-valued (varying) parameter in {}'.format(variable_params))

            varying_vals = variable_params[varying_key]
            n_param_sets = len(varying_vals)

            varying_params.append((varying_key, varying_vals))
            fixed_params.append(
                [(k, v) for k, v in variable_params.items() if k != varying_key])

            # expand the other parameters into lists so they can all be indexed together;
            # this builds a new dict so the caller's VARIABLE_PARAMS stays untouched

            param_lists = {}

            for key, vals in variable_params.items():
                if isinstance(vals, list):
                    param_lists[key] = vals
                else:
                    param_lists[key] = [vals] * n_param_sets

            early_late_heading_diffs = []
            early_late_heading_diffs_lb = []
            early_late_heading_diffs_ub = []

            for param_set_ctr in range(n_param_sets):

                threshold = param_lists['threshold'][param_set_ctr]
                tau_memory = param_lists['tau_memory'][param_set_ctr]
                k_s_diag = param_lists['sqrt_k_s'][param_set_ctr] ** 2

                k_0 = np.array([
                    [SQRT_K_0 ** 2, 0],
                    [0, SQRT_K_0 ** 2],
                ])
                k_s = np.array([
                    [k_s_diag, 0],
                    [0, k_s_diag],
                ])

                # build tracking agent

                ag = CenterlineInferringAgent(
                    tau=TAU, noise=NOISE, bias=BIAS, threshold=threshold,
                    hit_trigger='peak', hit_influence=HIT_INFLUENCE, tau_memory=tau_memory,
                    k_0=k_0, k_s=k_s, bounds=BOUNDS)

                # generate trajectories from random start positions

                trajs = []

                for _ in range(N_TRAJS):

                    start_pos = np.array([
                        np.random.uniform(*BOUNDS[0]),
                        np.random.uniform(*BOUNDS[1]),
                        np.random.uniform(*BOUNDS[2]),
                    ])

                    traj = ag.track(plume=pl, start_pos=start_pos, duration=DURATION, dt=DT)
                    traj['headings'] = heading(traj['vs'])[:, 2]

                    trajs.append(traj)

                # collect peak-aligned heading windows for early and late crossings

                crossings_early = []
                crossings_late = []

                for traj in trajs:

                    starts, onsets, peak_times, offsets, ends = \
                        segment_by_threshold(traj['odors'], threshold)[0].T

                    for ctr, (start, peak_time, end) in enumerate(zip(starts, peak_times, ends)):

                        if not (H_MIN_PEAK <= traj['headings'][peak_time] < H_MAX_PEAK):
                            continue

                        if not (X_MIN_PEAK <= traj['xs'][peak_time, 0] < X_MAX_PEAK):
                            continue

                        # samples available on either side of the peak, capped at the
                        # window size; whatever is not filled stays nan
                        n_pre = min(ts_before, peak_time - start)
                        n_post = min(ts_after, end - peak_time)

                        crossing = np.nan * np.zeros((ts_before + ts_after,))
                        crossing[ts_before - n_pre:ts_before] = \
                            traj['headings'][peak_time - n_pre:peak_time]
                        crossing[ts_before:ts_before + n_post] = \
                            traj['headings'][peak_time:peak_time + n_post]

                        if SUBTRACT_PEAK_HEADING:
                            crossing -= crossing[ts_before]

                        if ctr < EARLY_LESS_THAN:
                            crossings_early.append(crossing)
                        else:
                            crossings_late.append(crossing)

                crossings_early = np.array(crossings_early)
                crossings_late = np.array(crossings_late)

                h_mean_early = np.nanmean(crossings_early, axis=0)
                h_mean_late = np.nanmean(crossings_late, axis=0)

                h_sem_early = nansem(crossings_early, axis=0)
                h_sem_late = nansem(crossings_late, axis=0)

                h_mean_diff = h_mean_late - h_mean_early

                # most pessimistic / optimistic difference when both means move by one sem
                h_mean_diff_lb = h_mean_late - h_sem_late - (h_mean_early + h_sem_early)
                h_mean_diff_ub = h_mean_late + h_sem_late - (h_mean_early - h_sem_early)

                early_late_heading_diffs.append(h_mean_diff[in_int_window].mean())
                early_late_heading_diffs_lb.append(h_mean_diff_lb[in_int_window].mean())
                early_late_heading_diffs_ub.append(h_mean_diff_ub[in_int_window].mean())

            early_late_heading_diffs_all.append(np.array(early_late_heading_diffs))
            early_late_heading_diffs_lb_all.append(np.array(early_late_heading_diffs_lb))
            early_late_heading_diffs_ub_all.append(np.array(early_late_heading_diffs_ub))

        # save results (wrapped in a length-1 object array so np.save can store the dict)
        # NOTE(review): np.save appends '.npy' when SAVE_FILE lacks that extension, in which
        # case the isfile check above will not find the saved file -- confirm callers pass
        # a name ending in '.npy'.

        results = np.array([
            {
                'varying_params': varying_params,
                'fixed_params': fixed_params,
                'early_late_heading_diffs_all': early_late_heading_diffs_all,
                'early_late_heading_diffs_lb_all': early_late_heading_diffs_lb_all,
                'early_late_heading_diffs_ub_all': early_late_heading_diffs_ub_all,
             }])
        np.save(SAVE_FILE, results)

    results = results[0]

    ## MAKE PLOTS

    fig_size = (5 * AX_GRID[1], 4 * AX_GRID[0])
    # squeeze=False keeps axs a 2D array even for a single row, column, or axis
    fig, axs = plt.subplots(*AX_GRID, figsize=fig_size, tight_layout=True, squeeze=False)

    for ax_ctr in range(len(results['varying_params'])):

        ax = axs.flatten()[ax_ctr]

        ys_plot = results['early_late_heading_diffs_all'][ax_ctr]
        ys_lb = results['early_late_heading_diffs_lb_all'][ax_ctr]
        ys_ub = results['early_late_heading_diffs_ub_all'][ax_ctr]
        ys_err = [ys_plot - ys_lb, ys_ub - ys_plot]

        xs_name, x_ticklabels = results['varying_params'][ax_ctr][:2]
        xs_plot = np.arange(len(ys_plot))

        ax.errorbar(
            xs_plot, ys_plot, yerr=ys_err, color='k', fmt='--o')

        ax.axhline(0, color='gray')

        # choose y-limits that always include zero, with a 10% margin

        lb_min = np.min(ys_lb)
        ub_max = np.max(ys_ub)

        if ub_max > 0:
            y_range = ub_max - lb_min
        else:
            y_range = -lb_min

        y_min = lb_min - 0.1 * y_range
        y_max = max(ub_max, 0) + 0.1 * y_range

        ax.set_xlim(-1, len(ys_plot))
        ax.set_xticks(xs_plot)

        if xs_name == 'threshold':

            # presumably rescales the threshold to % ethanol (matching the x-label
            # below) -- TODO confirm the origin of the 0.0476/526 factor
            x_ticklabels = ['{0:.4f}'.format(xtl * (0.0476/526)) for xtl in x_ticklabels]

        ax.set_xticklabels(x_ticklabels)

        ax.set_ylim(y_min, y_max)

        if xs_name == 'tau_memory':
            x_label = 'tau_m (s)'
        elif xs_name == 'threshold':
            x_label = 'threshold (% ethanol)'
        else:
            x_label = xs_name

        ax.set_xlabel(x_label)
        ax.set_ylabel('mean heading difference\nfor late vs. early crossings')

    for ax in axs.flatten():

        set_font_size(ax, 16)

    return fig
            )

            if SUBTRACT_INITIAL_HEADING:

                headings -= headings[N_TIMESTEPS_BEFORE]

            crossings_time_series[crossing_ctr, :] = headings
            crossing_ctr += 1

        crossings_time_series = crossings_time_series[:crossing_ctr, :]

        print('{} total crossings'.format(len(crossings_time_series)))

        # get mean and sem
        crossings_mean = np.nanmean(crossings_time_series, axis=0)
        crossings_sem = stats.nansem(crossings_time_series, axis=0)
        crossings_std = np.nanstd(crossings_time_series, axis=0)

        # make plot
        handle, = axs[0, col_ctr].plot(time_vec, crossings_mean, lw=2, color=EXPT_COLORS[expt.id])
        axs[0, col_ctr].fill_between(
            time_vec, crossings_mean - crossings_std, crossings_mean + crossings_std,
            color=EXPT_COLORS[expt.id], alpha=0.3
        )

        if col_ctr == 0:
            wind_speed_handles.append(handle)

        # make querysets for early and late crossings
        crossings_early = crossings.filter(models.Crossing.crossing_number <= 2)
        crossings_late = crossings.filter(models.Crossing.crossing_number > 2)
def crossing_triggered_headings_early_late(
        SEED, N_TRAJS, DURATION, DT,
        TAU, NOISE, BIAS, AGENT_THRESHOLD,
        HIT_INFLUENCE, TAU_MEMORY,
        K_0, K_S, BOUNDS,
        PL_CONC, PL_MEAN, PL_STD,
        ANALYSIS_THRESHOLD,
        H_MIN_PEAK, H_MAX_PEAK,
        X_MIN_PEAK, X_MAX_PEAK,
        EARLY_LESS_THAN,
        SUBTRACT_PEAK_HEADING, T_BEFORE, T_AFTER):
    """
    Fly several agents through a simulated plume and plot their plume-crossing-triggered
    headings separately for early and late crossings.

    A crossing is "early" when its index within its trajectory is below EARLY_LESS_THAN and
    "late" otherwise. Only crossings whose peak heading lies in [H_MIN_PEAK, H_MAX_PEAK) and
    whose peak x-position lies in [X_MIN_PEAK, X_MAX_PEAK) are used.

    The figure contains four panels: mean +/- SEM heading time courses for early and late
    crossings, a histogram of the number of crossings per trajectory, the x-y path of the
    first trajectory, and a 2D occupancy histogram over (up to 3000) trajectories.

    :return: the figure
    """

    # build plume and agent

    pl = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)

    k_0 = K_0 * np.eye(2)
    k_s = K_S * np.eye(2)

    ag = CenterlineInferringAgent(
        tau=TAU, noise=NOISE, bias=BIAS, threshold=AGENT_THRESHOLD,
        hit_trigger='peak', hit_influence=HIT_INFLUENCE,
        k_0=k_0, k_s=k_s, tau_memory=TAU_MEMORY, bounds=BOUNDS)

    # GENERATE TRAJECTORIES

    np.random.seed(SEED)

    trajs = []

    for _ in range(N_TRAJS):

        # choose random start position

        start_pos = np.array([
            np.random.uniform(*BOUNDS[0]),
            np.random.uniform(*BOUNDS[1]),
            np.random.uniform(*BOUNDS[2]),
        ])

        # make trajectory

        traj = ag.track(plume=pl, start_pos=start_pos, duration=DURATION, dt=DT)
        traj['headings'] = heading(traj['vs'])[:, 2]

        trajs.append(traj)

    # ANALYZE TRAJECTORIES

    n_crossings = []

    # collect early and late crossings

    crossings_early = []
    crossings_late = []

    ts_before = int(T_BEFORE / DT)
    ts_after = int(T_AFTER / DT)

    for traj in trajs:

        starts, onsets, peak_times, offsets, ends = \
            segment_by_threshold(traj['odors'], ANALYSIS_THRESHOLD)[0].T

        n_crossings.append(len(peak_times))

        for ctr, (start, peak_time, end) in enumerate(zip(starts, peak_times, ends)):

            if not (H_MIN_PEAK <= traj['headings'][peak_time] < H_MAX_PEAK):
                continue

            if not (X_MIN_PEAK <= traj['xs'][peak_time, 0] < X_MAX_PEAK):
                continue

            # samples available on either side of the peak, capped at the window size;
            # whatever is not filled stays nan
            n_pre = min(ts_before, peak_time - start)
            n_post = min(ts_after, end - peak_time)

            crossing = np.nan * np.zeros((ts_before + ts_after,))
            crossing[ts_before - n_pre:ts_before] = \
                traj['headings'][peak_time - n_pre:peak_time]
            crossing[ts_before:ts_before + n_post] = \
                traj['headings'][peak_time:peak_time + n_post]

            if SUBTRACT_PEAK_HEADING:
                crossing -= crossing[ts_before]

            if ctr < EARLY_LESS_THAN:
                crossings_early.append(crossing)
            else:
                crossings_late.append(crossing)

    n_crossings = np.array(n_crossings)

    crossings_early = np.array(crossings_early)
    crossings_late = np.array(crossings_late)

    t = np.arange(-ts_before, ts_after) * DT

    fig, axs = plt.figure(figsize=(15, 15), tight_layout=True), []

    axs.append(fig.add_subplot(3, 2, 1))
    axs.append(fig.add_subplot(3, 2, 2))

    # panel 0: mean +/- sem heading time course for each group

    handles = []

    for label, color, crossings_group in (
            ('early', 'b', crossings_early), ('late', 'g', crossings_late)):

        # A group without any crossings has no time course to draw; skip it explicitly
        # instead of relying on a blanket try/except around the plotting calls.
        if len(crossings_group) == 0:
            continue

        h_mean = np.nanmean(crossings_group, axis=0)
        h_sem = nansem(crossings_group, axis=0)

        handles.append(axs[0].plot(t, h_mean, lw=3, color=color, label=label)[0])
        axs[0].fill_between(t, h_mean - h_sem, h_mean + h_sem, color=color, alpha=0.2)

    axs[0].axvline(0, ls='--', color='gray')

    axs[0].set_xlabel('time since peak (s)')

    if SUBTRACT_PEAK_HEADING:
        axs[0].set_ylabel('change in heading (deg)')
    else:
        axs[0].set_ylabel('heading (deg)')

    axs[0].legend(handles=handles, fontsize=16)

    # panel 1: distribution of the number of crossings per trajectory

    bin_min = -0.5
    bin_max = n_crossings.max() + 0.5

    # unit-width bins centered on the integers 0..max; the number of edges must be an int
    n_bin_edges = int(n_crossings.max()) + 2
    bins = np.linspace(bin_min, bin_max, n_bin_edges, endpoint=True)

    # Weighting every trajectory by 1/N yields the proportion per bin. With unit-width bins
    # this equals the normalized density the removed `normed=True` option produced.
    weights = np.ones(len(n_crossings)) / float(len(n_crossings))

    axs[1].hist(n_crossings, bins=bins, lw=0, weights=weights)
    axs[1].set_xlim(bin_min, bin_max)

    axs[1].set_xlabel('number of crossings')
    axs[1].set_ylabel('proportion of trajectories')

    # panel 2: example trajectory

    axs.append(fig.add_subplot(3, 1, 2))

    axs[2].plot(trajs[0]['xs'][:, 0], trajs[0]['xs'][:, 1])
    axs[2].axhline(0, color='gray', ls='--')

    axs[2].set_xlabel('x (m)')
    axs[2].set_ylabel('y (m)')

    # panel 3: x-y occupancy over at most the first 3000 trajectories

    axs.append(fig.add_subplot(3, 1, 3))

    all_xy = np.concatenate([traj['xs'][:, :2] for traj in trajs[:3000]], axis=0)
    x_bins = np.linspace(BOUNDS[0][0], BOUNDS[0][1], 66, endpoint=True)
    y_bins = np.linspace(BOUNDS[1][0], BOUNDS[1][1], 30, endpoint=True)

    axs[3].hist2d(all_xy[:, 0], all_xy[:, 1], bins=(x_bins, y_bins))

    axs[3].set_xlabel('x (m)')
    axs[3].set_ylabel('y (m)')

    for ax in axs:

        set_font_size(ax, 20)

    return fig
def crossing_triggered_headings_all(
        SEED,
        N_TRAJS, DURATION, DT,
        TAU, NOISE, BIAS, AGENT_THRESHOLD,
        HIT_INFLUENCE, TAU_MEMORY, K_0, K_S,
        BOUNDS,
        PL_CONC, PL_MEAN, PL_STD,
        ANALYSIS_THRESHOLD, H_MIN_PEAK, H_MAX_PEAK,
        SUBTRACT_PEAK_HEADING, T_BEFORE, T_AFTER, Y_LIM):
    """
    Generate simulated plume-tracking trajectories and plot every crossing-triggered
    heading time course together with the across-crossing mean and SEM.

    Crossings come from thresholding each trajectory's odor signal at ANALYSIS_THRESHOLD
    and are kept only if the heading at the crossing's peak is in [H_MIN_PEAK, H_MAX_PEAK).
    Heading windows are aligned to the peak and nan-padded where the crossing is shorter
    than the requested window.

    :return: the figure containing the plot
    """

    ts_before = int(T_BEFORE / DT)
    ts_after = int(T_AFTER / DT)

    def peak_aligned_headings(hs, start, peak_time, end):
        """Return the nan-padded heading window around one crossing's peak."""
        window = np.nan * np.ones((ts_before + ts_after,))

        n_pre = peak_time - start
        n_post = end - peak_time

        if n_pre < ts_before:
            # crossing starts inside the window: right-align what is available
            window[ts_before - n_pre:ts_before] = hs[start:peak_time]
        else:
            window[:ts_before] = hs[peak_time - ts_before:peak_time]

        if n_post < ts_after:
            # crossing ends inside the window: left-align what is available
            window[ts_before:ts_before + n_post] = hs[peak_time:end]
        else:
            window[ts_before:] = hs[peak_time:peak_time + ts_after]

        return window

    # plume and agent
    plume = GaussianLaminarPlume(PL_CONC, PL_MEAN, PL_STD)

    gain_0 = K_0 * np.eye(2)
    gain_s = K_S * np.eye(2)

    agent = CenterlineInferringAgent(
        tau=TAU, noise=NOISE, bias=BIAS, threshold=AGENT_THRESHOLD,
        hit_trigger='peak', hit_influence=HIT_INFLUENCE,
        k_0=gain_0, k_s=gain_s, tau_memory=TAU_MEMORY, bounds=BOUNDS)

    # trajectories from uniformly random start positions
    np.random.seed(SEED)

    trajs = []
    for _ in range(N_TRAJS):
        x_start = np.random.uniform(*BOUNDS[0])
        y_start = np.random.uniform(*BOUNDS[1])
        z_start = np.random.uniform(*BOUNDS[2])

        traj = agent.track(
            plume=plume, start_pos=np.array([x_start, y_start, z_start]),
            duration=DURATION, dt=DT)
        traj['headings'] = heading(traj['vs'])[:, 2]
        trajs.append(traj)

    # crossing-triggered heading windows
    crossings = []
    for traj in trajs:
        starts, _onsets, peak_times, _offsets, ends = \
            segment_by_threshold(traj['odors'], ANALYSIS_THRESHOLD)[0].T

        for start, peak_time, end in zip(starts, peak_times, ends):
            if H_MIN_PEAK <= traj['headings'][peak_time] < H_MAX_PEAK:
                crossing = peak_aligned_headings(traj['headings'], start, peak_time, end)

                if SUBTRACT_PEAK_HEADING:
                    crossing -= crossing[ts_before]

                crossings.append(crossing)

    crossings = np.array(crossings)

    t = np.arange(-ts_before, ts_after) * DT

    fig, ax = plt.subplots(1, 1, figsize=(8, 6), tight_layout=True)

    h_mean = np.nanmean(crossings, axis=0)
    h_sem = nansem(crossings, axis=0)
    h_lower = h_mean - h_sem
    h_upper = h_mean + h_sem

    ax.plot(t, crossings.T, lw=0.5, alpha=0.5, color='c', zorder=0)
    ax.plot(t, h_mean, lw=3, color='k')
    ax.fill_between(t, h_lower, h_upper, color='k', alpha=0.2)

    ax.axvline(0, ls='--', color='gray')

    ax.set_ylim(*Y_LIM)
    ax.set_xlabel('time since peak (s)')

    y_label = 'heading (deg)'
    if SUBTRACT_PEAK_HEADING:
        y_label = 'change in heading (deg)'
    ax.set_ylabel(y_label)

    set_font_size(ax, 16)

    return fig
Exemple #14
0
def infotaxis_analysis(
        WIND_TUNNEL_CG_IDS,
        INFOTAXIS_WIND_SPEED_CG_IDS,
        MAX_CROSSINGS,
        INFOTAXIS_HISTORY_DEPENDENCE_CG_IDS,
        MAX_CROSSINGS_EARLY,
        X_0_MIN, X_0_MAX, H_0_MIN, H_0_MAX,
        X_0_MIN_SIM, X_0_MAX_SIM,
        X_0_MIN_SIM_HISTORY, X_0_MAX_SIM_HISTORY,
        T_BEFORE_EXPT, T_AFTER_EXPT,
        TS_BEFORE_SIM, TS_AFTER_SIM, HEADING_SMOOTHING_SIM,
        HEAT_MAP_EXPT_ID, HEAT_MAP_SIM_ID,
        N_HEAT_MAP_TRAJS,
        X_BINS, Y_BINS,
        FIG_SIZE, FONT_SIZE,
        EXPT_LABELS,
        EXPT_COLORS,
        SIM_LABELS):
    """
    Show infotaxis-generated trajectories alongside empirical trajectories. Show wind-speed
    dependence and history dependence.

    For every crossing group, peak-triggered heading time series are collected (at most
    MAX_CROSSINGS per group) for crossings whose peak x-position and peak heading fall within
    the given inclusive limits; the heading at the peak is subtracted from each series.
    Simulated headings are additionally smoothed with a Gaussian filter of width
    HEADING_SMOOTHING_SIM over their non-nan samples. For the history-dependence groups,
    crossings are split into "early" (crossing number <= MAX_CROSSINGS_EARLY) and "late".

    The figure shows mean +/- SEM headings for the wind tunnel groups and the infotaxis
    groups, early vs. late headings for (up to four) history-dependence groups, and 2D
    position histograms for one experiment and one simulation.

    :param N_HEAT_MAP_TRAJS: maximum number of trajectories/trials used for the heat maps;
        if falsy, all are used
    :return: the figure
    """

    from db_api.infotaxis import models as models_infotaxis
    from db_api.infotaxis.connect import session as session_infotaxis

    ts_before_expt = int(round(T_BEFORE_EXPT / DT))
    ts_after_expt = int(round(T_AFTER_EXPT / DT))

    headings = {}

    # get headings for wind tunnel plume crossings

    headings['wind_tunnel'] = {}

    for cg_id in WIND_TUNNEL_CG_IDS:

        crossings_all = session.query(models.Crossing).filter_by(crossing_group_id=cg_id).all()

        headings_cg = []

        for crossing in crossings_all:

            if len(headings_cg) >= MAX_CROSSINGS:
                break

            # skip this crossing if it doesn't meet our inclusion criteria

            x_0 = crossing.feature_set_basic.position_x_peak
            h_0 = crossing.feature_set_basic.heading_xyz_peak

            if not (X_0_MIN <= x_0 <= X_0_MAX):
                continue

            if not (H_0_MIN <= h_0 <= H_0_MAX):
                continue

            # get crossing heading, subtract heading at peak, and store it

            temp = crossing.timepoint_field(
                session, 'heading_xyz', -ts_before_expt, ts_after_expt - 1,
                'peak', 'peak', nan_pad=True)

            temp -= temp[ts_before_expt]

            headings_cg.append(temp)

        headings['wind_tunnel'][cg_id] = np.array(headings_cg)

    # get headings from infotaxis plume crossings

    headings['infotaxis'] = {}

    for cg_id in INFOTAXIS_WIND_SPEED_CG_IDS:

        crossings_all = list(session_infotaxis.query(models_infotaxis.Crossing).filter_by(
            crossing_group_id=cg_id).all())

        print('{} crossings for infotaxis crossing group: "{}"'.format(
            len(crossings_all), cg_id))

        headings_cg = []

        for crossing in crossings_all:

            if len(headings_cg) >= MAX_CROSSINGS:
                break

            # skip this crossing if it doesn't meet our inclusion criteria

            x_0 = crossing.feature_set_basic.position_x_peak
            h_0 = crossing.feature_set_basic.heading_xyz_peak

            if not (X_0_MIN_SIM <= x_0 <= X_0_MAX_SIM):
                continue

            if not (H_0_MIN <= h_0 <= H_0_MAX):
                continue

            # get crossing heading and smooth its non-nan part

            temp = crossing.timepoint_field(
                session_infotaxis, 'hxyz', -TS_BEFORE_SIM, TS_AFTER_SIM - 1,
                'peak', 'peak', nan_pad=True)

            temp[~np.isnan(temp)] = gaussian_filter1d(
                temp[~np.isnan(temp)], HEADING_SMOOTHING_SIM)

            # subtract heading at peak and store result

            temp -= temp[TS_BEFORE_SIM]

            headings_cg.append(temp)

        headings['infotaxis'][cg_id] = np.array(headings_cg)

    # get history dependences for infotaxis simulations

    headings['it_hist_dependence'] = {}

    for cg_id in INFOTAXIS_HISTORY_DEPENDENCE_CG_IDS:

        crossings_all = list(session_infotaxis.query(models_infotaxis.Crossing).filter_by(
            crossing_group_id=cg_id).all())

        headings_early = []
        headings_late = []

        cr_ctr = 0

        for crossing in crossings_all:

            if cr_ctr >= MAX_CROSSINGS:
                break

            # skip this crossing if it doesn't meet our inclusion criteria

            x_0 = crossing.feature_set_basic.position_x_peak
            h_0 = crossing.feature_set_basic.heading_xyz_peak

            if not (X_0_MIN_SIM_HISTORY <= x_0 <= X_0_MAX_SIM_HISTORY):
                continue

            if not (H_0_MIN <= h_0 <= H_0_MAX):
                continue

            # get crossing heading and smooth its non-nan part

            temp = crossing.timepoint_field(
                session_infotaxis, 'hxyz', -TS_BEFORE_SIM, TS_AFTER_SIM - 1,
                'peak', 'peak', nan_pad=True)

            temp[~np.isnan(temp)] = gaussian_filter1d(
                temp[~np.isnan(temp)], HEADING_SMOOTHING_SIM)

            # subtract heading at peak

            temp -= temp[TS_BEFORE_SIM]

            # store according to its crossing number

            if crossing.crossing_number <= MAX_CROSSINGS_EARLY:
                headings_early.append(temp)
            elif crossing.crossing_number > MAX_CROSSINGS_EARLY:
                headings_late.append(temp)
            else:
                # only reachable when the crossing number compares as neither (e.g. nan)
                raise Exception('crossing number is not early or late for crossing {}'.format(
                    crossing.id))

            cr_ctr += 1

        # convert inside the loop so that every crossing group (not just the last one)
        # ends up as an array
        headings['it_hist_dependence'][cg_id] = {
            'early': np.array(headings_early),
            'late': np.array(headings_late),
        }

    # get heatmaps

    trajs_expt = session.query(models.Trajectory).\
        filter_by(experiment_id=HEAT_MAP_EXPT_ID, odor_state='on')
    trials_sim = session_infotaxis.query(models_infotaxis.Trial).\
        filter_by(simulation_id=HEAT_MAP_SIM_ID)

    if N_HEAT_MAP_TRAJS:

        trajs_expt = trajs_expt.limit(N_HEAT_MAP_TRAJS)
        trials_sim = trials_sim.limit(N_HEAT_MAP_TRAJS)

    expt_xs = []
    expt_ys = []

    sim_xs = []
    sim_ys = []

    for traj in trajs_expt:

        expt_xs.append(traj.timepoint_field(session, 'position_x'))
        expt_ys.append(traj.timepoint_field(session, 'position_y'))

    for trial in trials_sim:

        sim_xs.append(trial.timepoint_field(session_infotaxis, 'xidx'))
        sim_ys.append(trial.timepoint_field(session_infotaxis, 'yidx'))

    expt_xs = np.concatenate(expt_xs)
    expt_ys = np.concatenate(expt_ys)

    # presumably converts simulation grid indices to positions in meters
    # (grid spacing 0.02, origin offset) -- TODO confirm against the simulation setup
    sim_xs = np.concatenate(sim_xs) * 0.02 - 0.3
    sim_ys = np.concatenate(sim_ys) * 0.02 - 0.15

    ## MAKE PLOTS

    fig, axs = plt.figure(figsize=FIG_SIZE, tight_layout=True), []

    axs.append(fig.add_subplot(4, 3, 1))
    axs.append(fig.add_subplot(4, 3, 2, sharey=axs[0]))

    # plot wind-speed dependence of wind tunnel trajectories

    t_expt = np.arange(-ts_before_expt, ts_after_expt) * DT

    handles = []

    for cg_id in WIND_TUNNEL_CG_IDS:

        label = EXPT_LABELS[cg_id]
        color = EXPT_COLORS[cg_id]

        headings_mean = np.nanmean(headings['wind_tunnel'][cg_id], axis=0)
        headings_sem = stats.nansem(headings['wind_tunnel'][cg_id], axis=0)

        # plot mean and sem

        handles.append(
            axs[0].plot(t_expt, headings_mean, lw=3, color=color, zorder=1, label=label)[0])
        axs[0].fill_between(
            t_expt, headings_mean - headings_sem, headings_mean + headings_sem,
            color=color, alpha=0.2)

    axs[0].set_xlabel('time since odor peak (s)')
    axs[0].set_ylabel(r'$\Delta$ heading (degrees)')
    axs[0].set_title('experimental data\n(wind speed comparison)')

    axs[0].legend(handles=handles, loc='best')

    # plot wind-speed dependence of infotaxis trajectories (time axis in time steps);
    # labels and colors are borrowed from the wind tunnel group at the same position

    t_sim = np.arange(-TS_BEFORE_SIM, TS_AFTER_SIM)

    for cg_id, wt_cg_id in zip(INFOTAXIS_WIND_SPEED_CG_IDS, WIND_TUNNEL_CG_IDS):

        label = EXPT_LABELS[wt_cg_id]
        color = EXPT_COLORS[wt_cg_id]

        headings_mean = np.nanmean(headings['infotaxis'][cg_id], axis=0)
        headings_sem = stats.nansem(headings['infotaxis'][cg_id], axis=0)

        # plot mean and sem

        axs[1].plot(t_sim, headings_mean, lw=3, color=color, zorder=1, label=label)
        axs[1].fill_between(
            t_sim, headings_mean - headings_sem, headings_mean + headings_sem,
            color=color, alpha=0.2)

    axs[1].set_xlabel('time steps since odor peak (s)')
    axs[1].set_title('infotaxis simulations\n(wind speed comparison)')

    # add axes for infotaxis history dependence and make plots

    for ctr in range(4):

        axs.append(fig.add_subplot(4, 3, 3 + ctr))

    for (ax, cg_id) in zip(axs[-4:], INFOTAXIS_HISTORY_DEPENDENCE_CG_IDS):

        # handles are collected per axis so that a legend never picks up a line that
        # belongs to a previously drawn axis
        hist_handles = []

        for label, color in (('early', 'b'), ('late', 'g')):

            headings_group = headings['it_hist_dependence'][cg_id][label]

            # a group without crossings has no time course to draw; skip it explicitly
            # instead of relying on a blanket try/except around the plotting calls
            if len(headings_group) == 0:
                continue

            group_mean = np.nanmean(headings_group, axis=0)
            group_sem = stats.nansem(headings_group, axis=0)

            # plot mean and sem

            hist_handles.append(
                ax.plot(t_sim, group_mean, lw=3, color=color, zorder=0, label=label)[0])
            ax.fill_between(
                t_sim, group_mean - group_sem, group_mean + group_sem,
                color=color, alpha=0.2)

        ax.set_xlabel('time steps since odor peak (s)')
        ax.set_title(SIM_LABELS[cg_id])

        if hist_handles:

            ax.legend(handles=hist_handles)

    axs[3].set_ylabel(r'$\Delta$ heading (degrees)')

    # plot heat maps

    axs.append(fig.add_subplot(4, 1, 3))
    axs.append(fig.add_subplot(4, 1, 4))

    axs[6].hist2d(expt_xs, expt_ys, bins=(X_BINS, Y_BINS))
    axs[7].hist2d(sim_xs, sim_ys, bins=(X_BINS, Y_BINS))

    axs[6].set_ylabel('y (m)')
    axs[7].set_ylabel('y (m)')
    axs[7].set_xlabel('x (m)')

    axs[6].set_title('experimental data (fly 0.4 m/s)')
    axs[7].set_title('infotaxis simulation')

    for ax in axs:

        set_fontsize(ax, FONT_SIZE)

    return fig
Exemple #15
0
def early_vs_late_heading_timecourse(
        CROSSING_GROUP_IDS, CROSSING_GROUP_LABELS,
        X_0_MIN, X_0_MAX, H_0_MIN, H_0_MAX,
        MAX_CROSSINGS_EARLY, SUBTRACT_INITIAL_HEADING,
        T_BEFORE, T_AFTER,
        AX_SIZE, AX_GRID, EARLY_LATE_COLORS, ALPHA,
        P_VAL_COLOR, P_VAL_Y_LIM, LEGEND_CROSSING_GROUP_ID,
        X_0_BINS, FONT_SIZE):
    """
    Show early vs. late headings for different experiments, along with a plot of the
    p-values for the difference between the two means.

    Crossings whose crossing_number is <= MAX_CROSSINGS_EARLY are "early"; all others
    are "late". A second figure shows the distribution of x-positions at the odor peak
    for the early and late groups of each crossing group.

    :param CROSSING_GROUP_IDS: iterable of crossing group ids; one axis is drawn per id
    :param CROSSING_GROUP_LABELS: mapping from crossing group id to the axis title
    :param X_0_MIN: lower (inclusive) bound on x-position at peak for a crossing to be kept
    :param X_0_MAX: upper (inclusive) bound on x-position at peak for a crossing to be kept
    :param H_0_MIN: lower (inclusive) bound on heading at peak for a crossing to be kept
    :param H_0_MAX: upper (inclusive) bound on heading at peak for a crossing to be kept
    :param MAX_CROSSINGS_EARLY: largest crossing number still counted as "early"
    :param SUBTRACT_INITIAL_HEADING: if truthy, subtract each crossing's heading at the
        peak from its whole heading time-series
    :param T_BEFORE: time before the peak to show (converted to time steps using DT)
    :param T_AFTER: time after the peak to show (converted to time steps using DT)
    :param AX_SIZE: (width, height) of a single axis
    :param AX_GRID: (n_rows, n_cols) of the axis grid
    :param EARLY_LATE_COLORS: mapping from group label ('early', 'late') to plot color
    :param ALPHA: transparency of the SEM bands
    :param P_VAL_COLOR: color of the p-value trace
    :param P_VAL_Y_LIM: (lower, upper) y-limits of the p-value axis
    :param LEGEND_CROSSING_GROUP_ID: id of the crossing group whose axis gets a legend
    :param X_0_BINS: bin edges for the x_0 histograms
    :param FONT_SIZE: font size applied to every axis
    :return: heading time-course figure, x_0 histogram figure
    """

    # convert times to time steps

    ts_before = int(round(T_BEFORE / DT))
    ts_after = int(round(T_AFTER / DT))

    # loop over crossing groups

    x_0s_dict = {}
    headings_dict = {}
    p_vals_dict = {}

    for cg_id in CROSSING_GROUP_IDS:

        # get crossing group

        crossing_group = session.query(models.CrossingGroup).filter_by(id=cg_id).first()

        # get early and late crossings

        crossings_all = session.query(models.Crossing).filter_by(crossing_group=crossing_group)

        crossings_dict = {
            'early': crossings_all.filter(
                models.Crossing.crossing_number <= MAX_CROSSINGS_EARLY),
            'late': crossings_all.filter(
                models.Crossing.crossing_number > MAX_CROSSINGS_EARLY),
        }

        x_0s_dict[cg_id] = {}
        headings_dict[cg_id] = {}

        for label in ['early', 'late']:

            x_0s = []
            headings = []

            # get all initial xs and heading time-series

            for crossing in crossings_dict[label]:

                # throw away crossings that do not meet trigger criteria

                x_0 = crossing.feature_set_basic.position_x_peak
                h_0 = crossing.feature_set_basic.heading_xyz_peak

                if not (X_0_MIN <= x_0 <= X_0_MAX):

                    continue

                if not (H_0_MIN <= h_0 <= H_0_MAX):

                    continue

                # store x_0 (uw/dw position)

                x_0s.append(x_0)

                # get heading time-series around the peak

                temp = crossing.timepoint_field(
                    session, 'heading_xyz', -ts_before, ts_after - 1,
                    'peak', 'peak', nan_pad=True)

                # subtract initial heading if desired (index ts_before is the peak itself)

                if SUBTRACT_INITIAL_HEADING:

                    temp -= temp[ts_before]

                headings.append(temp)

            x_0s_dict[cg_id][label] = np.array(x_0s)
            headings_dict[cg_id][label] = np.array(headings)

        # loop through all time points and calculate p-value (ks-test) between early and late

        p_vals = []

        for t_step in range(ts_before + ts_after):

            early_with_nans = headings_dict[cg_id]['early'][:, t_step]
            late_with_nans = headings_dict[cg_id]['late'][:, t_step]

            early_no_nans = early_with_nans[~np.isnan(early_with_nans)]
            late_no_nans = late_with_nans[~np.isnan(late_with_nans)]

            # a nan-padded time step can leave one group with no valid samples; the
            # KS test is undefined there, so record NaN rather than calling ks_2samp
            # on empty data

            if early_no_nans.size and late_no_nans.size:

                p_vals.append(ks_2samp(early_no_nans, late_no_nans)[1])

            else:

                p_vals.append(np.nan)

        p_vals_dict[cg_id] = p_vals

    ## MAKE PLOTS

    t = np.arange(-ts_before, ts_after) * DT

    # history-dependence

    fig_size = (AX_SIZE[0] * AX_GRID[1], AX_SIZE[1] * AX_GRID[0])

    # squeeze=False guarantees a 2D array of axes, so .flat also works for a 1 x 1 grid

    fig_0, axs_0 = plt.subplots(
        *AX_GRID, figsize=fig_size, tight_layout=True, squeeze=False)

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_0.flat):

        # get mean and sem of headings for early and late groups

        handles = []

        for label, color in EARLY_LATE_COLORS.items():

            headings_mean = np.nanmean(headings_dict[cg_id][label], axis=0)
            headings_sem = stats.nansem(headings_dict[cg_id][label], axis=0)

            handles.append(ax.plot(t, headings_mean, color=color, lw=2, label=label, zorder=1)[0])
            ax.fill_between(
                t, headings_mean - headings_sem, headings_mean + headings_sem,
                color=color, alpha=ALPHA, zorder=1)

        ax.set_xlabel('time since crossing (s)')

        if SUBTRACT_INITIAL_HEADING:

            ax.set_ylabel(r'$\Delta$ heading (deg.)')

        else:

            ax.set_ylabel('heading (deg.)')

        ax.set_title(CROSSING_GROUP_LABELS[cg_id])

        if cg_id == LEGEND_CROSSING_GROUP_ID:

            ax.legend(handles=handles, loc='upper right')

        set_fontsize(ax, FONT_SIZE)

        # plot p-value on a twin axis, with a reference line at p = 0.05

        ax_twin = ax.twinx()

        ax_twin.plot(t, p_vals_dict[cg_id], color=P_VAL_COLOR, lw=2, ls='--', zorder=0)

        ax_twin.axhline(0.05, ls='-', lw=2, color='gray')

        ax_twin.set_ylim(*P_VAL_Y_LIM)
        ax_twin.set_ylabel('p-value (KS test)', fontsize=FONT_SIZE)

        set_fontsize(ax_twin, FONT_SIZE)

    # position histograms

    # coerce to an array so the bin-center arithmetic also works for a plain list of edges

    bin_edges = np.asarray(X_0_BINS)
    bincs = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    fig_1, axs_1 = plt.subplots(
        *AX_GRID, figsize=fig_size, tight_layout=True, squeeze=False)

    for cg_id, ax in zip(CROSSING_GROUP_IDS, axs_1.flat):

        # create early and late histograms

        handles = []

        for label, color in EARLY_LATE_COLORS.items():

            # density=True replaces the `normed` keyword, which NumPy no longer accepts

            probs = np.histogram(x_0s_dict[cg_id][label], bins=bin_edges, density=True)[0]

            # bin centers are scaled by 100 to match the 'cm' axis label

            handles.append(ax.plot(100*bincs, probs, lw=2, color=color, label=label)[0])

        p_val = ks_2samp(x_0s_dict[cg_id]['early'], x_0s_dict[cg_id]['late'])[1]
        x_0_mean_diff = x_0s_dict[cg_id]['late'].mean() - x_0s_dict[cg_id]['early'].mean()

        ax.set_xlabel(r'$x_0$ (cm)')
        ax.set_ylabel('proportion of\ncrossings')

        # NOTE(review): '{1:10.2}' has no type code, so it formats to 2 significant
        # digits (possibly in exponent notation) rather than 2 decimals -- confirm
        # whether '.2f' was intended

        title = '{0} ($\\Delta mean(x_0)$ = {1:10.2} cm) \n(p = {2:10.5f} [KS test])'.format(
            CROSSING_GROUP_LABELS[cg_id], 100 * x_0_mean_diff, p_val)

        ax.set_title(title)

        ax.legend(handles=handles, loc='best')

        set_fontsize(ax, FONT_SIZE)

    return fig_0, fig_1
Exemple #16
0
def example_traj_and_crossings(
        SEED,
        EXPT_ID,
        TRAJ_NUMBER,
        TRAJ_START_TP, TRAJ_END_TP,
        CROSSING_GROUP,
        N_CROSSINGS,
        X_0_MIN, X_0_MAX, H_0_MIN, H_0_MAX,
        MIN_PEAK_CONC,
        TS_BEFORE_3D, TS_AFTER_3D,
        TS_BEFORE_HEADING, TS_AFTER_HEADING,
        FIG_SIZE,
        SCATTER_SIZE,
        CYL_STDS, CYL_COLOR, CYL_ALPHA,
        EXPT_LABEL,
        FONT_SIZE):
    """
    Show an example trajectory through a wind tunnel plume with the crossings marked.
    Show many crossings overlaid on the plume in 3D and show the mean peak-triggered heading
    with its SEM as well as many individual examples.

    :param SEED: seed for numpy's global RNG, used to shuffle the crossings
    :param EXPT_ID: experiment id; selects the trajectories and the plume parameters
    :param TRAJ_NUMBER: if an int, index into the experiment's clean odor-on trajectories;
        otherwise used as the trajectory id
    :param TRAJ_START_TP: first timepoint of the example trajectory to plot
    :param TRAJ_END_TP: end timepoint (exclusive slice bound) of the example trajectory
    :param CROSSING_GROUP: id of the crossing group to draw crossings from
    :param N_CROSSINGS: maximum number of individual example crossings to plot
    :param X_0_MIN: lower (inclusive) bound on x-position at peak for a crossing to be kept
    :param X_0_MAX: upper (inclusive) bound on x-position at peak for a crossing to be kept
    :param H_0_MIN: lower (inclusive) bound on heading at peak for a crossing to be kept
    :param H_0_MAX: upper (inclusive) bound on heading at peak for a crossing to be kept
    :param MIN_PEAK_CONC: only crossings with max_odor above this value are considered
    :param TS_BEFORE_3D: time steps before the peak shown in the 3D crossing plot
    :param TS_AFTER_3D: time steps after the peak shown in the 3D crossing plot
    :param TS_BEFORE_HEADING: time steps before the peak in the heading time-series
    :param TS_AFTER_HEADING: time steps after the peak in the heading time-series
    :param FIG_SIZE: figure size passed to matplotlib
    :param SCATTER_SIZE: marker size for the 3D scatter plots
    :param CYL_STDS: plume cylinder semi-axes, in units of the plume's y/z standard deviations
    :param CYL_COLOR: color of the plume cylinder
    :param CYL_ALPHA: transparency of the plume cylinder
    :param EXPT_LABEL: not used by this function
    :param FONT_SIZE: font size applied to every axis
    :return: figure
    """

    # pick the example trajectory

    if isinstance(TRAJ_NUMBER, int):

        trajs = session.query(models.Trajectory).filter_by(
            experiment_id=EXPT_ID, odor_state='on', clean=True).all()

        traj = list(trajs)[TRAJ_NUMBER]

    else:

        traj = session.query(models.Trajectory).filter_by(id=TRAJ_NUMBER).first()

    # get plottable quantities

    x_traj, y_traj, z_traj = traj.positions(session).T[:, TRAJ_START_TP:TRAJ_END_TP]
    c_traj = traj.odors(session)[TRAJ_START_TP:TRAJ_END_TP]

    # get crossings with a sufficiently high peak concentration, in random order

    crossings_all = list(
        session.query(models.Crossing).filter_by(
            crossing_group_id=CROSSING_GROUP).filter(
            models.Crossing.max_odor > MIN_PEAK_CONC).all())

    np.random.seed(SEED)

    plot_idxs = np.random.permutation(len(crossings_all))

    crossing_examples = []
    headings = []

    for idx in plot_idxs:

        crossing = crossings_all[idx]

        # throw away crossings that do not meet trigger criteria

        x_0 = crossing.feature_set_basic.position_x_peak

        if not (X_0_MIN <= x_0 <= X_0_MAX):

            continue

        h_0 = crossing.feature_set_basic.heading_xyz_peak

        if not (H_0_MIN <= h_0 <= H_0_MAX):

            continue

        # fetch the heading time-series once; it feeds both the population
        # statistics and (for the first N_CROSSINGS kept crossings) the examples

        heading = crossing.timepoint_field(
            session, 'heading_xyz', -TS_BEFORE_HEADING, TS_AFTER_HEADING - 1,
            'peak', 'peak', nan_pad=True)

        headings.append(heading)

        # store example crossing if we do not have enough examples yet

        if len(crossing_examples) < N_CROSSINGS:

            crossing_dict = {
                field: crossing.timepoint_field(
                    session, field, -TS_BEFORE_3D, TS_AFTER_3D, 'peak', 'peak')
                for field in ['position_x', 'position_y', 'position_z', 'odor']
            }

            crossing_dict['heading'] = heading

            crossing_examples.append(crossing_dict)

    headings = np.array(headings)

    ## MAKE PLOTS

    fig, axs = plt.figure(figsize=FIG_SIZE, tight_layout=True), []

    # build the plume cylinder surface (elliptical cross-section, axis along x)

    plume_params = PLUME_PARAMS_DICT[EXPT_ID]

    cyl_mean_y = plume_params['ymean']
    cyl_mean_z = plume_params['zmean']

    cyl_scale_y = CYL_STDS * plume_params['ystd']
    cyl_scale_z = CYL_STDS * plume_params['zstd']

    max_conc = plume_params['max_conc']

    y = np.linspace(-1, 1, 100, endpoint=True)
    x = np.linspace(-0.3, 1, 5, endpoint=True)
    yy, xx = np.meshgrid(y, x)

    # upper half of the unit circle; mirrored below to get the lower half
    zz = np.sqrt(1 - yy ** 2)

    yy = cyl_scale_y * yy + cyl_mean_y
    zz_top = cyl_scale_z * zz + cyl_mean_z
    zz_bottom = -cyl_scale_z * zz + cyl_mean_z

    def _draw_plume_cylinder(ax):
        # draw the top and bottom halves of the plume cylinder on a 3D axis

        for zz_half in (zz_top, zz_bottom):

            ax.plot_surface(
                xx, yy, zz_half, lw=0,
                color=CYL_COLOR, alpha=CYL_ALPHA, rstride=20, cstride=10)

    def _format_tunnel_axis(ax):
        # apply the shared limits, ticks and labels to a 3D axis
        # (tick positions are relabeled as the same values multiplied by 100, in cm)

        ax.set_xlim(-0.3, 1)
        ax.set_ylim(-0.15, 0.15)
        ax.set_zlim(-0.15, 0.15)

        ax.set_xticks([-0.3, 1.])
        ax.set_yticks([-0.15, 0.15])
        ax.set_zticks([-0.15, 0.15])

        ax.set_xticklabels([-30, 100])
        ax.set_yticklabels([-15, 15])
        ax.set_zticklabels([-15, 15])

        ax.set_xlabel('x (cm)')
        ax.set_ylabel('y (cm)')
        ax.set_zlabel('z (cm)')

    # plot example trajectory, colored by odor concentration

    axs.append(fig.add_subplot(3, 1, 1, projection='3d'))

    _draw_plume_cylinder(axs[0])

    axs[0].scatter(
        x_traj, y_traj, z_traj, c=c_traj, s=SCATTER_SIZE,
        vmin=0, vmax=max_conc / 2, cmap=cmx.hot, lw=0, alpha=1)

    _format_tunnel_axis(axs[0])

    # plot several crossings, colored by odor concentration

    axs.append(fig.add_subplot(3, 1, 2, projection='3d'))

    _draw_plume_cylinder(axs[1])

    for crossing in crossing_examples:

        axs[1].scatter(
            crossing['position_x'], crossing['position_y'], crossing['position_z'],
            c=crossing['odor'], s=SCATTER_SIZE,
            vmin=0, vmax=max_conc / 2, cmap=cmx.hot, lw=0, alpha=1)

    _format_tunnel_axis(axs[1])

    # plot headings

    axs.append(fig.add_subplot(3, 2, 6))

    t = np.arange(-TS_BEFORE_HEADING, TS_AFTER_HEADING) * DT

    headings_mean = np.nanmean(headings, axis=0)
    headings_std = np.nanstd(headings, axis=0)
    headings_sem = stats.nansem(headings, axis=0)

    # plot example crossings

    for crossing in crossing_examples:

        axs[2].plot(t, crossing['heading'], lw=1, color='k', alpha=0.5, zorder=-1)

    # plot mean (solid), mean +/- std (dashed), and mean +/- sem (shaded band)

    axs[2].plot(t, headings_mean, lw=3, color='k', zorder=1)
    axs[2].plot(t, headings_mean - headings_std, lw=3, ls='--', color='k', zorder=1)
    axs[2].plot(t, headings_mean + headings_std, lw=3, ls='--', color='k', zorder=1)
    axs[2].fill_between(
        t, headings_mean - headings_sem, headings_mean + headings_sem,
        color='k', alpha=0.2)

    axs[2].set_xlabel('time since crossing (s)')
    axs[2].set_ylabel('heading (degrees)')

    for ax in axs:

        set_fontsize(ax, FONT_SIZE)

    return fig