def test_observation_error_increases_with_real_trajectory_length(self):
    """A longer real trajectory must yield a strictly larger error.

    The same simulated trajectory is compared against a short (2-step) and
    a long (3-step) real trajectory; the long one must score worse in every
    observation dimension.
    """
    short_real_obs = np.array([[1], [2]])
    long_real_obs = np.array([[1], [2], [3]])
    sim_obs = np.array([[0], [1]])
    short_real, long_real, sim = [
        self._make_trajectory(obs)
        for obs in (short_real_obs, long_real_obs, sim_obs)
    ]

    short_error = simple.calculate_observation_error(
        real_trajectories=[short_real], sim_trajectories=[sim])
    long_error = simple.calculate_observation_error(
        real_trajectories=[long_real], sim_trajectories=[sim])

    # Element-wise strict comparison: every dimension must grow.
    np.testing.assert_array_less(short_error, long_error)
def test_observation_error_positive_for_different_trajectories(self):
    """Trajectories differing at a single step give a strictly positive error."""
    first_obs = np.array([[1], [2], [3]])
    # Differs from first_obs only in the first step.
    second_obs = np.array([[0], [2], [3]])
    first_traj, second_traj = [
        self._make_trajectory(obs) for obs in (first_obs, second_obs)
    ]

    observed_error = simple.calculate_observation_error(
        [first_traj], [second_traj])

    # Checks 0 < error element-wise.
    np.testing.assert_array_less([0], observed_error)
def test_observation_error_reduces_over_trajectories(self):
    """Passing several trajectory pairs still yields one value per obs dim.

    With 1-dimensional observations and two trajectory pairs, the result
    must have shape ``(1,)``, i.e. the trajectory axis is not kept.
    """
    first_obs = np.array([[1], [2], [3]])
    second_obs = np.array([[0], [2], [3]])
    first_traj, second_traj = [
        self._make_trajectory(obs) for obs in (first_obs, second_obs)
    ]

    real_batch = [first_traj, first_traj]
    sim_batch = [second_traj, second_traj]
    observed_error = simple.calculate_observation_error(real_batch, sim_batch)

    self.assertEqual(observed_error.shape, (1, ))
def test_observation_error_dims_correspond_to_observation_dims(self):
    """Error is reported per observation dimension.

    Only the middle column differs between the two trajectories, so only
    the middle entry of the error vector may be non-zero.
    """
    varying_middle_obs = np.array([[0, 1, 0], [0, 2, 0], [0, 3, 0]])
    all_zero_obs = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
    varying_traj, zero_traj = [
        self._make_trajectory(obs)
        for obs in (varying_middle_obs, all_zero_obs)
    ]

    per_dim_error = simple.calculate_observation_error(
        [varying_traj], [zero_traj])

    self.assertEqual(per_dim_error.shape, (3,))
    # Untouched first column -> no error.
    np.testing.assert_array_almost_equal(per_dim_error[0], 0)
    # Differing middle column -> non-zero error.
    self.assertFalse(np.allclose(per_dim_error[1], 0))
    # Untouched last column -> no error.
    np.testing.assert_array_almost_equal(per_dim_error[2], 0)
def test_observation_error_zero_for_same_trajectories(self):
    """Comparing a trajectory with an identical copy gives (almost) zero error."""
    shared_obs = np.array([[0], [2], [1]])
    # Two trajectories built independently from the same observations.
    real_traj, sim_traj = [
        self._make_trajectory(obs) for obs in (shared_obs, shared_obs)
    ]

    observed_error = simple.calculate_observation_error([real_traj], [sim_traj])

    np.testing.assert_array_almost_equal(observed_error, [0])