Beispiel #1
0
def test_step_iter(mock_data, data_format: str):
    """
    Build a step sequence from the mock data and check the steps obtained by iterating over it.

    :param mock_data: tuple holding rewards, states, observations, actions, hidden states, and policy infos
    :param data_format: data format handed to the `StepSequence` as well as to `to_format`
    """
    rewards, states, observations, actions, hidden, policy_infos = mock_data

    # Keep the keyword order, it could matter to the constructor if it collects extra fields
    fields = dict(
        rewards=rewards,
        observations=observations,
        states=states,
        actions=actions,
        policy_infos=policy_infos,
        hidden=hidden,
        data_format=data_format,
    )
    ro = StepSequence(**fields)

    num_steps = len(ro)
    assert num_steps == 5

    for idx, step in enumerate(ro):
        # The reward is compared directly against the raw mock value
        assert step.reward == rewards[idx]

        # Observation of this step and of the one following it
        obs_now = step.observation
        assert (obs_now == to_format(observations[idx], data_format)).all()
        obs_next = step.next_observation
        assert (obs_next == to_format(observations[idx + 1], data_format)).all()

        # Entry of the policy info dict, accessed as an attribute
        info_mean = step.policy_info.mean
        assert (info_mean == to_format(policy_infos[idx]["mean"], data_format)).all()

        # First part of the hidden state
        first_hidden = step.hidden[0]
        assert (first_hidden == to_format(hidden[idx][0], data_format)).all()
Beispiel #2
0
def test_to_format(data_type):
    """
    Check that `to_format` converts between numpy arrays and PyTorch tensors and applies the requested data type.

    :param data_type: pair of (torch dtype, numpy dtype) to request; an entry of `None` means "keep the type"
    """
    # Inputs for both directions, both created as 64-bit floats
    np_input = np.random.rand(3, 2).astype(dtype=np.float64)
    torch_input = to.rand(3, 2).type(dtype=to.float64)

    # numpy -> PyTorch
    torch_dtype = data_type[0]
    as_tensor = to_format(np_input, 'torch', torch_dtype)
    assert isinstance(as_tensor, to.Tensor)
    if torch_dtype is None:
        # Passing None must not change the type
        expected_dtype = to.float64
    else:
        expected_dtype = torch_dtype
    assert as_tensor.dtype == expected_dtype

    # PyTorch -> numpy
    numpy_dtype = data_type[1]
    as_ndarray = to_format(torch_input, 'numpy', numpy_dtype)
    assert isinstance(as_ndarray, np.ndarray)
    if numpy_dtype is None:
        # Passing None must not change the type
        expected_dtype = np.float64
    else:
        expected_dtype = numpy_dtype
    assert as_ndarray.dtype == expected_dtype
Beispiel #3
0
def test_namedtuple(data_format):
    """
    Check that hidden states given as `DummyNT` instances keep that type, on the sequence and on every step.

    `hidden` and `rewards` are not parameters of this test; they are looked up in the enclosing scope.

    :param data_format: data format handed to the `StepSequence` as well as to `to_format`
    """
    # Wrap every raw hidden state into the named tuple
    hid_nt = []
    for raw_parts in hidden:
        hid_nt.append(DummyNT(*raw_parts))

    ro = StepSequence(rewards=rewards, hidden=hid_nt, data_format=data_format)

    # The field of the whole sequence is a named tuple, too
    assert isinstance(ro.hidden, DummyNT)

    for idx, step in enumerate(ro):
        step_hidden = step.hidden
        assert isinstance(step_hidden, DummyNT)

        first_part = step.hidden.part1
        assert (first_part == to_format(hid_nt[idx].part1, data_format)).all()
Beispiel #4
0
def test_namedtuple(mock_data, data_format: str):
    """
    Check that hidden states given as `DummyNT` instances keep that type, on the sequence and on every step.

    :param mock_data: tuple holding rewards, states, observations, actions, hidden states, and policy infos
    :param data_format: data format handed to the `StepSequence` as well as to `to_format`
    """
    rewards, states, observations, actions, hidden, policy_infos = mock_data

    # Wrap every raw hidden state into the named tuple
    hid_nt = []
    for raw_parts in hidden:
        hid_nt.append(DummyNT(*raw_parts))

    fields = dict(
        rewards=rewards,
        actions=actions,
        observations=observations,
        hidden=hid_nt,
        data_format=data_format,
    )
    ro = StepSequence(**fields)

    # The field of the whole sequence is a named tuple, too
    assert isinstance(ro.hidden, DummyNT)

    for idx, step in enumerate(ro):
        step_hidden = step.hidden
        assert isinstance(step_hidden, DummyNT)

        first_part = step.hidden.part1
        assert (first_part == to_format(hid_nt[idx].part1, data_format)).all()
    def convert(self, data_format: str, data_type=None):
        """
        Convert all data fields to the specified format and, if requested, to the specified data type.

        :param data_format: 'torch' to use Tensors, 'numpy' to use ndarrays
        :param data_type: optional torch/numpy dtype for data. When `None` is passed, the data type is left unchanged.
        :raises pyrado.ValueErr: if `data_format` is neither 'torch' nor 'numpy'
        """
        if data_format not in {'torch', 'numpy'}:
            raise pyrado.ValueErr(given=data_format,
                                  eq_constraint="'torch' or 'numpy'")

        # Only skip the work if nothing has to change. Returning as soon as the format matches would silently
        # ignore an explicitly requested data type.
        # NOTE(review): assumes `to_format` applies `data_type` to data that already is in `data_format` - confirm
        if self._data_format == data_format and data_type is None:
            return

        self._data_format = data_format
        for dn in self._data_names:
            # Every field is mapped through the class' private tensor-mapping helper
            self.__dict__[dn] = self.__map_tensors(
                lambda t: to_format(t, data_format, data_type),
                self.__dict__[dn])
    def add_data(self,
                 name: str,
                 value=None,
                 item_shape: tuple = None,
                 with_after_last: bool = False):
        """
        Add a new data field to the step sequence. Adding a field under a name that already exists is an error.

        :param name: string for the name of the new field
        :param value: the data. If `None`, a zero-filled tensor/array (depending on the sequence's data format) is
                      created instead
        :param item_shape: shape of a single item, only used if `value` is `None`. Passing `None` is treated as the
                           empty shape, i.e. one scalar per step
        :param with_after_last: `True` if there is one more element than the length (e.g. last observation), only
                                used if `value` is `None`
        :raises pyrado.ValueErr: if there already is a data field called `name`
        """
        if name in self._data_names:
            raise pyrado.ValueErr(
                msg=f'Trying to add a duplicate data field for {name}')

        if value is None:
            # Compute desired step length
            ro_length = self.length
            if with_after_last:
                ro_length += 1

            # Leading dimension is the number of steps, the remaining ones are given by the item shape
            item_dims = () if item_shape is None else tuple(item_shape)
            full_shape = (ro_length, ) + item_dims

            # Create zero-filled. Note that np.array(shape) would store the shape values themselves.
            if self._data_format == 'torch':
                value = to.zeros(full_shape)
            else:
                value = np.zeros(full_shape)

        else:
            # Check the given data using the class' own validation
            self._validate_data_size(name, value)

            if not isinstance(value, (np.ndarray, to.Tensor)):
                # Stack into one array/tensor
                value = stack_to_format(value, self._data_format)
            else:
                # Ensure right array format
                value = to_format(value, self._data_format)

        # Register the name and store the data as an attribute of the sequence
        self._data_names.append(name)
        self.__dict__[name] = value