Exemplo n.º 1
0
def test_discrete_space():
    """Check ``contains`` and ``sample`` on a ``bindings.Discrete(12)`` space."""
    space = bindings.Discrete(12)

    # The values -1 and 13 are expected to be rejected by the space.
    for outlier in (-1, 13):
        assert not space.contains(bindings.Sample([outlier])), \
            "Discrete object failed to verify if a sample belongs to its space"

    # Draw 50 samples: each one must hold a single int accepted by the space.
    for _ in range(50):
        drawn = space.sample()
        assert drawn.getBuffer_i().size() == 1, \
            "Wrong size of the sample extracted from a discrete space"
        assert isinstance(drawn.getBuffer_i()[0], int), \
            "Wrong data type of the sample extracted from a discrete space"
        assert space.contains(drawn), \
            "Sampled data is not contained in the discrete space object that created it"
Exemplo n.º 2
0
def test_sample(create_std_vector):
    """Check that a Sample mirrors its source vector and can be edited in place.

    Args:
        create_std_vector: The vector used to build the ``bindings.Sample``
            (presumably provided by a pytest fixture -- confirm against the
            test configuration). It must expose ``size()`` and indexing, and
            hold at least three elements since index 2 is overwritten below.
    """
    vector = create_std_vector
    sample = bindings.Sample(vector)

    # Compare every element. The upper bound is ``size()`` and not
    # ``size() - 1``, otherwise the last element would never be checked.
    for i in range(vector.size()):
        assert sample.getBuffer_d()[i] == vector[i], \
            "Sample object does not contain correct data"

    # Write through the buffer accessor, then read the value back through both
    # the element accessor and the buffer accessor.
    sample.getBuffer_d()[2] = 42
    assert sample.get_d(2).value() == 42, "Failed to insert data in the Sample object"
    assert sample.getBuffer_d()[2] == 42, "Failed to update data of a Sample object"
Exemplo n.º 3
0
def test_box_space():
    """Check ``contains`` and ``sample`` on a ``bindings.Box(-1, 42, [4])`` space."""
    dim = 4
    space = bindings.Box(-1, 42, [dim])

    def as_sample(values):
        # A plain python list is not passed directly: the values go through
        # Vector_d so that the sample is built from a buffer of doubles.
        return bindings.Sample(bindings.Vector_d(values))

    assert space.contains(as_sample([0, pi, 12, 42])), \
        "Box object failed to verify if a sample belongs to its space"

    # Samples expected to be rejected: one holding the value 43, a one-element
    # buffer, and a five-element buffer (which also holds the value 43).
    rejected = ([0, pi, 12, 43], [0], [0, pi, 12, 43, 0])
    for values in rejected:
        assert not space.contains(as_sample(values)), \
            "Box object failed to verify if a sample belongs to its space"

    # Draw 50 samples: each one must have `dim` float entries and be accepted
    # by the space that generated it.
    for _ in range(50):
        drawn = space.sample()
        assert drawn.getBuffer_d().size() == dim, \
            "Wrong size of the sample extracted from a box space"
        assert isinstance(drawn.getBuffer_d()[0], float), \
            "Wrong data type of the sample extracted from the box space"
        assert space.contains(drawn), \
            "Sampled data is not contained in the box space object that created it"
Exemplo n.º 4
0
    def step(self, action: Action) -> State:
        """Apply an action to the wrapped gympp environment and return the result.

        Args:
            action: The action to execute. It must belong to
                ``self.action_space``. Numpy arrays and numpy numeric scalars
                are converted to plain python objects before reaching the
                bindings.

        Returns:
            A ``State`` built from the tuple ``(observation, reward, done, info)``
            where ``info`` stores the gympp info under the ``'gympp'`` key.

        Raises:
            AssertionError: If the action is outside the action space, the
                numpy input is neither an array nor a number, the step returns
                no state, the observation buffer is missing or empty, the
                observation space is not a Box or a Discrete, or the
                observation is outside the observation space.
        """
        assert self.action_space.contains(action), \
            "The action does not belong to the action space"

        # The bindings do not accept numpy types yet, so numpy inputs are
        # converted to the closest python type. A numpy input is recognized by
        # the module in which its type is defined.
        if type(action).__module__ == np.__name__:
            if isinstance(action, np.ndarray):
                action = action.tolist()
            elif isinstance(action, np.number):
                action = action.item()
            else:
                assert False

        # The bindings expect std::vector objects: a scalar action is wrapped
        # in a one-element list, anything else is copied into a list.
        values = [action] if isinstance(action, Number) else list(action)

        # Build the gympp::Sample from a vector of the action data type
        vector_cls = getattr(bindings, 'Vector' + self._act_dt)
        gympp_action = bindings.Sample(vector_cls(values))

        # Step the environment: the result is a std::optional<gympp::State>
        maybe_state = self.gympp_env.step(gympp_action)
        assert maybe_state.has_value()
        gympp_state = maybe_state.value()

        # Access the std::vector buffer of the gympp::Observation
        get_buffer = getattr(gympp_state.observation, 'getBuffer' + self._obs_dt)
        buffer = get_buffer()
        assert buffer, "Failed to get the observation buffer"
        assert buffer.size() > 0, "The observation does not contain elements"

        # Copy the SWIG container into a python list
        items = list(buffer)

        # Shape the observation according to the type of the observation space
        space = self.observation_space
        if isinstance(space, gym.spaces.Box):
            observation = np.array(items)
        elif isinstance(space, gym.spaces.Discrete):
            assert buffer.size() == 1, "The buffer has the wrong dimension"
            observation = items[0]
        else:
            assert False, "Space not supported"

        assert space.contains(observation), \
            "The returned observation does not belong to the space"

        # Expose the gympp info through the info dict
        info = {'gympp': gympp_state.info}

        return State((observation, gympp_state.reward, gympp_state.done, info))