def sub_rot_at_max_elev(shoulder_traj: ShoulderTrajInterp, traj_def: str,
                        decomp_method: str, sub_rot: Union[int, None],
                        norm_by: Union[str, None]) -> np.ndarray:
    """Return a subrotation's value at maximum HT elevation, optionally normalized.

    The value at maximum elevation is the last sample of the ``'up'`` interpolation of the requested subrotation
    (as returned by ``extract_sub_rot``). When ``norm_by`` is given, the first sample of that normalization section
    is subtracted from it.

    Parameters
    ----------
    shoulder_traj : ShoulderTrajInterp
        Interpolated shoulder trajectory.
    traj_def : str
        Joint to extract, e.g. 'ht', 'gh', 'st'.
    decomp_method : str
        Decomposition method, e.g. 'euler.ht_isb', 'true_axial_rot', 'induced_axial_rot'.
    sub_rot : int or None
        Subrotation index (e.g. 0, 1, 2, None). Ignored for 'true_axial_rot'.
    norm_by : str or None
        Section whose first sample is subtracted, e.g. 'traj', 'up', 'down'. ``None`` disables normalization.

    Returns
    -------
    The (optionally normalized) subrotation value at the end of the 'up' section.
    """
    y = extract_sub_rot(shoulder_traj, traj_def, 'up', decomp_method, sub_rot)[-1]
    if norm_by is None:
        return y

    # first extract ht, gh, or st
    joint_traj = getattr(shoulder_traj, traj_def)
    # The scalar axial-rotation interpolations are stored as 'axial_rot' / 'induced_axial_rot' for the whole
    # trajectory and with an '_<section>' suffix for the other sections.
    suffix = '' if norm_by == 'traj' else '_' + norm_by
    if decomp_method == 'true_axial_rot':
        y0 = getattr(joint_traj, 'axial_rot' + suffix)[0]
    elif decomp_method == 'induced_axial_rot':
        y0 = getattr(joint_traj, 'induced_axial_rot' + suffix)[0, sub_rot]
    else:
        y0 = rgetattr(getattr(joint_traj, norm_by), decomp_method)[0, sub_rot]

    return y - y0
def extract_sub_rot_diff(shoulder_traj: ShoulderTrajInterp, traj_def: str,
                         y_def: str, decomp_method: str,
                         sub_rot: Union[int, None]) -> np.ndarray:
    """Return the up-minus-down difference of a subrotation for one joint.

    Extracts the requested subrotation from the ``<y_def>_up`` and ``<y_def>_down`` interpolations and subtracts
    the latter from the former. The lookup rules (including the special handling of 'true_axial_rot' and
    'induced_axial_rot') are exactly those of ``extract_sub_rot``.

    Parameters
    ----------
    shoulder_traj : ShoulderTrajInterp
        Interpolated shoulder trajectory.
    traj_def : str
        Joint to extract, e.g. 'ht', 'gh', 'st'.
    y_def : str
        Interpolation base name, without the '_up'/'_down' suffix.
    decomp_method : str
        Decomposition method, e.g. 'euler.ht_isb', 'true_axial_rot', 'induced_axial_rot'.
    sub_rot : int or None
        Subrotation index (e.g. 0, 1, 2, None). Ignored for 'true_axial_rot'.
    """
    # Delegate to extract_sub_rot so the attribute-dispatch rules live in one place.
    y_up = extract_sub_rot(shoulder_traj, traj_def, y_def + '_up', decomp_method, sub_rot)
    y_down = extract_sub_rot(shoulder_traj, traj_def, y_def + '_down', decomp_method, sub_rot)
    return y_up - y_down
def extract_sub_rot(shoulder_traj: ShoulderTrajInterp, traj_def: str,
                    y_def: str, decomp_method: str,
                    sub_rot: Union[int, None]) -> np.ndarray:
    """Return one subrotation of one joint from an interpolated shoulder trajectory.

    Parameters
    ----------
    shoulder_traj : ShoulderTrajInterp
        Interpolated shoulder trajectory.
    traj_def : str
        Joint attribute to read, e.g. 'ht', 'gh', 'st'.
    y_def : str
        Interpolation name, e.g. 'common_fine_up'.
    decomp_method : str
        'true_axial_rot', 'induced_axial_rot', or a (possibly dotted) attribute path on the interpolated
        pose trajectory, e.g. 'euler.ht_isb'.
    sub_rot : int or None
        Column to select (e.g. 0, 1, 2, None). Not used for 'true_axial_rot'.
    """
    joint = getattr(shoulder_traj, traj_def)

    # True axial rotation is path dependent, so each JointTrajectory stores its own scalar interpolation of it
    # (e.g. it makes no sense to compute it on a trajectory that starts at 25 degrees). Those are read straight
    # from the joint rather than from the interpolated PoseTrajectory.
    if decomp_method == 'true_axial_rot':
        return getattr(joint, 'axial_rot_' + y_def)

    if decomp_method == 'induced_axial_rot':
        induced = getattr(joint, 'induced_axial_rot_' + y_def)
        return induced[:, sub_rot]

    pose_traj = getattr(joint, y_def)
    decomposed = rgetattr(pose_traj, decomp_method)
    return decomposed[:, sub_rot]
def true_axial_analysis(df_row, traj_def, euler_def, path_fnc, add_axial_rot_fnc):
    """Compare apparent (Euler) axial rotation with true axial rotation over the longest upward run.

    Parameters
    ----------
    df_row : row-like object
        Indexed by ``traj_def`` (the trajectory) and by 'up_down_analysis' (an object exposing
        ``max_run_up_start_idx`` and ``max_run_up_end_idx``).
    traj_def : str
        Key of the trajectory in ``df_row``.
    euler_def : str
        (Possibly dotted) attribute path to the Euler angle decomposition on the trajectory (radians).
    path_fnc : callable
        ``path_fnc(orient_deg, start_idx, end_idx) -> (longitude, latitude)`` describing the path polygon.
    add_axial_rot_fnc : callable
        ``add_axial_rot_fnc(orient_deg, start_idx, end_idx)``; its result is returned unchanged.

    Returns
    -------
    tuple
        (apparent axial change [deg], true axial change [deg], enclosed area, add_axial_rot_fnc result,
        whether the spherical polygon is clockwise).
    """
    up_down = df_row['up_down_analysis']
    first = up_down.max_run_up_start_idx
    last = up_down.max_run_up_end_idx

    traj = df_row[traj_def]
    orient_deg = np.rad2deg(rgetattr(traj, euler_def))
    true_axial_deg = np.rad2deg(traj.true_axial_rot)

    # change over the run: third Euler angle vs. the path-integrated true axial rotation
    apparent_diff = orient_deg[last, 2] - orient_deg[first, 2]
    true_diff = true_axial_deg[last] - true_axial_deg[first]
    extra_axial = add_axial_rot_fnc(orient_deg, first, last)

    lon, lat = path_fnc(orient_deg, first, last)

    # Spherical polygon centered near the middle sample of the run.
    # NOTE(review): the latitude of the center is halved here - confirm this is intended.
    mid = int((first + last) / 2)
    polygon = SphericalPolygon.from_lonlat(lon, lat, center=(orient_deg[mid, 0], orient_deg[mid, 1] / 2))
    area = np.rad2deg(polygon.area())

    # When the actual path and the "euler" path cross each other, spherical_geometry misestimates the area, so
    # it is folded back below 180 by repeated subtraction.
    while area > 180:
        area -= 180

    return apparent_diff, true_diff, area, extra_axial, polygon.is_clockwise()
def summary_plotter(shoulder_trajs,
                    traj_def,
                    x_ind_def,
                    y_ind_def,
                    x_cmn_def,
                    y_cmn_def,
                    decomp_method,
                    sub_rot,
                    ax,
                    ind_plotter_fnc,
                    avg_color,
                    quat_avg_color,
                    error_bars='se',
                    **error_bar_kwargs):
    """Plot individual trajectories, their mean with error bars, and (when applicable) a quaternion-mean overlay.

    Parameters
    ----------
    shoulder_trajs : pandas Series (presumably) of interpolated shoulder trajectories
        Must support ``.apply(fnc, args=[...])`` and ``.iloc[0]``.
    traj_def : str
        Joint, e.g. 'ht', 'gh', 'st'.
    x_ind_def, y_ind_def : str
        Passed to ``ind_plotter_fnc`` to select each individual trajectory's x and y data.
    x_cmn_def : str
        Attribute of a trajectory holding the common x-axis (read from the first trajectory only).
    y_cmn_def : str
        Interpolation name used for the common-axis y values.
    decomp_method : str
        Decomposition method, e.g. 'euler.ht_isb', 'true_axial_rot', 'induced_axial_rot'.
    sub_rot : int or None
        Subrotation index.
    ax : matplotlib Axes
        Target axes.
    ind_plotter_fnc : callable
        Applied to every trajectory to draw the individual curves.
    avg_color, quat_avg_color : color specs
        Colors of the mean-with-error-bars curve and of the quaternion-mean curve.
    error_bars : {'se', 'std'}, case-insensitive
        Standard error or standard deviation.
    **error_bar_kwargs
        Forwarded to ``ax.errorbar``.

    Returns
    -------
    tuple
        (individual plot lines, errorbar container, quaternion-mean lines or ``None`` for
        'true_axial_rot' / 'induced_axial_rot').

    Raises
    ------
    ValueError
        If ``error_bars`` is neither 'se' nor 'std'. Checked before anything is drawn.
    """
    # Validate first so an invalid option does not leave the axes half-drawn.
    err_mode = error_bars.lower()
    if err_mode not in ('se', 'std'):
        raise ValueError('The error bars can either be SE or STD.')

    # plot individual trajectories
    ind_traj_plot_lines = shoulder_trajs.apply(
        ind_plotter_fnc,
        args=[traj_def, x_ind_def, y_ind_def, decomp_method, sub_rot, ax])
    # common x-axis : only look at the first trajectory because it will be the same for all
    x_cmn = getattr(shoulder_trajs.iloc[0], x_cmn_def)
    # Extract the y-values spanning the common x-axis. This goes to each quaternion interpolated trajectory, and applies
    # decomp_method and sub_rot. The results are then averaged below. This is technically not correct because
    # mathematically it doesn't make sense to average PoE, Elevation, etc. but this is how other papers handle this step
    y_cmn_all = np.stack(
        shoulder_trajs.apply(
            extract_sub_rot,
            args=[traj_def, y_cmn_def, decomp_method, sub_rot]), 0)
    traj_mean, traj_std, traj_se = traj_stats(y_cmn_all)
    y_err = traj_se if err_mode == 'se' else traj_std
    agg_lines = ax.errorbar(x_cmn,
                            np.rad2deg(traj_mean),
                            yerr=np.rad2deg(y_err),
                            capsize=2,
                            color=avg_color,
                            zorder=4,
                            lw=2,
                            **error_bar_kwargs)

    # The scalar axial-rotation measures have no quaternion counterpart to average.
    if decomp_method in ('true_axial_rot', 'induced_axial_rot'):
        return ind_traj_plot_lines, agg_lines, None

    # So here the individual interpolated trajectories are averaged via quaternions. This is mathematically correct.
    # Then the averaged trajectory is decomposed according to decomp_method and sub_rot. One could then use this to
    # compute SD and SE, but I don't go that far since almost always the quaternion mean and the mean as computed above
    # match very well. But I do overlay the quaternion mean as a sanity check
    mean_traj_quat = quat_mean_trajs(
        np.stack(shoulder_trajs.apply(extract_interp_quat_traj,
                                      args=[traj_def, y_cmn_def]),
                 axis=0))
    # Dummy positions (only the orientation is used): one 3-vector per frame. This was previously built as
    # (n_frames, n_frames), which is the wrong shape for a position array.
    mean_traj_pos = np.zeros((mean_traj_quat.size, 3))
    mean_traj_pose = PoseTrajectory.from_quat(
        mean_traj_pos, q.as_float_array(mean_traj_quat))
    mean_y_quat = rgetattr(mean_traj_pose, decomp_method)[:, sub_rot]
    quat_mean_lines = ax.plot(x_cmn,
                              np.rad2deg(mean_y_quat),
                              color=quat_avg_color,
                              zorder=5,
                              lw=2)

    return ind_traj_plot_lines, agg_lines, quat_mean_lines