def test_observation_space():
    """Exercise ObservationSpace.contains and ObservationSpace.is_valid_index.

    The space under test is built with shape [1, 10] and the bounds -10 and 10
    (presumably low/high -- confirm against the ObservationSpace signature).
    """
    space = ObservationSpace(np.array([1, 10]), -10, 10)

    # contains(): an all-ones array of the declared shape is accepted ...
    assert space.contains(np.ones([1, 10]))
    # ... while a wrong leading dimension, values of 100, and an extra
    # dimension are all rejected.
    for rejected in (np.ones([2, 10]), np.ones([1, 10]) * 100, np.ones([1, 1, 10])):
        assert not space.contains(rejected)

    # is_valid_index(): indices that fall inside the [1, 10] shape are valid ...
    for index in ([0, 9], [0, 0]):
        assert space.is_valid_index(np.array(index))
    # ... while indices at/past either dimension's size, or negative, are not.
    for index in ([1, 8], [0, 10], [-1, 6]):
        assert not space.is_valid_index(np.array(index))
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
    """Check that an input observation space is compatible with this filter's state.

    :param input_observation_space: the observation space about to be fed into the filter
    :raises ValueError: if ``self.stack`` is non-empty and its most recent entry is not
        contained in ``input_observation_space``, or if ``self.stacking_axis`` is greater
        than or equal to the space's number of dimensions
    :return: None
    """
    # Only the most recent stacked observation is checked; presumably every entry in
    # the stack comes from the same space -- TODO confirm.
    if len(self.stack) > 0 and not input_observation_space.contains(self.stack[-1]):
        # The trailing space in the first literal is required: adjacent string literals
        # are concatenated verbatim, which previously produced "...stored inthe ...".
        raise ValueError("The given input observation space is different than the observations already stored in "
                         "the filters memory")
    # NOTE(review): only an upper bound is enforced here; a negative stacking_axis is
    # never rejected by this check.
    if input_observation_space.num_dimensions <= self.stacking_axis:
        raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")