def test_average_rt_matches_mean_euler_angles():
    """Check that averaging rotations agrees with averaging their Euler angles.

    Random angles in [0, 1) are turned into rototrans matrices. The Euler
    angles recovered from the averaged Rototrans must match, to 2 decimals,
    the time-mean of the Euler angles recovered frame by frame.
    """
    # This function was previously also named ``test_average_rt``. The later
    # definition with that name rebound it, so pytest never collected it; the
    # distinct name lets both tests run.

    # TODO: investigate why this does not work with other data, e.g.
    #   angles = Angles.from_random_data(size=(3, 1, 100))
    # or
    #   angles = Angles(np.arange(300).reshape((3, 1, 100)))
    # NOTE(review): presumably because angles far from zero wrap around or
    # land on another Euler decomposition branch -- confirm.
    restart_seed()  # presumably seeds numpy's global RNG -> deterministic data
    angles = Angles(np.random.rand(3, 1, 100))
    seq = "xyz"

    rt = Rototrans.from_euler_angles(angles, seq)
    rt_mean = Rototrans.from_averaged_rototrans(rt)
    # The averaged rototrans is read at time=0; presumably it holds one frame.
    angles_mean = Angles.from_rototrans(rt_mean, seq).isel(time=0)

    # Reference: decompose every frame, then average the angles over time.
    angles_mean_ref = Angles.from_rototrans(rt, seq).mean(dim="time")

    np.testing.assert_array_almost_equal(angles_mean,
                                         angles_mean_ref,
                                         decimal=2)


def test_average_rt():
    """Regression test for ``Rototrans.from_averaged_rototrans``.

    Seeded random angles are min-max rescaled into [0, 1] and converted to
    rototrans matrices, which are then averaged. The Euler angles decoded
    from the mean rototrans are compared against stored reference values.
    """
    sequence = "xyz"
    expected = [0.25265133, 0.57436872, 0.79133042]

    def _rescale(data):
        # Min-max normalization keeps the angles small (within [0, 1]).
        low, high = data.min(), data.max()
        return (data - low) / (high - low)

    restart_seed()
    small_angles = Angles.from_random_data(size=(3, 1, 100)).pipe(_rescale)

    mean_rt = Rototrans.from_averaged_rototrans(
        Rototrans.from_euler_angles(small_angles, sequence)
    )
    recovered = Angles.from_rototrans(mean_rt, sequence)

    np.testing.assert_almost_equal(
        recovered.data.ravel(),
        expected,
        decimal=4,
    )