Example #1
def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched
    observations.

    Args:
        observations: List of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.

    Returns:
        Transposed dict of torch.Tensor of observations.
    """
    batch: DefaultDict[str, List] = defaultdict(list)

    for obs in observations:
        for sensor in obs:
            batch[sensor].append(torch.as_tensor(obs[sensor]))

    batch_t: TensorDict = TensorDict()

    for sensor in batch:
        batch_t[sensor] = torch.stack(batch[sensor], dim=0)

    return batch_t.map(lambda v: v.to(device))
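A minimal usage sketch for the batch_obs above, assuming TensorDict behaves like a nested dict of tensors (as the tests below show); the sensor names and shapes are made up for illustration:

import numpy as np
import torch

# Two fake per-environment observations with the same sensor layout.
observations = [
    {"rgb": np.zeros((64, 64, 3), dtype=np.uint8),
     "depth": np.zeros((64, 64, 1), dtype=np.float32)},
    {"rgb": np.ones((64, 64, 3), dtype=np.uint8),
     "depth": np.ones((64, 64, 1), dtype=np.float32)},
]

batch = batch_obs(observations, device=torch.device("cpu"))
assert batch["rgb"].shape == (2, 64, 64, 3)  # leading dim is the batch size
assert batch["depth"].dtype == torch.float32
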
def test_tensor_dict_constructor():
    dict_tree = dict(a=torch.randn(2, 2),
                     b=dict(c=dict(d=np.random.randn(3, 3))))
    tensor_dict = TensorDict.from_tree(dict_tree)

    assert torch.is_tensor(tensor_dict["a"])
    assert isinstance(tensor_dict["b"], TensorDict)
    assert isinstance(tensor_dict["b"]["c"], TensorDict)
    assert torch.is_tensor(tensor_dict["b"]["c"]["d"])
def test_tensor_dict_map():
    dict_tree = dict(a=dict(b=[0]))
    tensor_dict = TensorDict.from_tree(dict_tree)

    res = tensor_dict.map(lambda x: x + 1)
    assert (res["a"]["b"] == 1).all()

    tensor_dict.map_in_place(lambda x: x + 1)

    assert res == tensor_dict
def test_tensor_dict_str_index():
    dict_tree = dict(a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3))))
    tensor_dict = TensorDict.from_tree(dict_tree)

    x = torch.randn(5, 5)
    tensor_dict["a"] = x
    assert (tensor_dict["a"] == x).all()

    with pytest.raises(KeyError):
        _ = tensor_dict["c"]
Example #5
def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
    cache: Optional[ObservationBatchingCache] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched
    observations.

    Args:
        observations: List of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.
        cache: An ObservationBatchingCache. This enables faster
            stacking of observations and CPU-GPU transfer as it
            maintains a correctly sized tensor for the batched
            observations that is pinned to CUDA memory.

    Returns:
        Transposed dict of torch.Tensor of observations.
    """
    batch_t: TensorDict = TensorDict()
    if cache is None:
        batch: DefaultDict[str, List] = defaultdict(list)

    for i, obs in enumerate(observations):
        for sensor_name, sensor in obs.items():
            sensor = torch.as_tensor(sensor)
            if cache is None:
                batch[sensor_name].append(sensor)
            else:
                if sensor_name not in batch_t:
                    batch_t[sensor_name] = cache.get(len(observations),
                                                     sensor_name, sensor,
                                                     device)

                batch_t[sensor_name][i].copy_(sensor)

    if cache is None:
        for sensor in batch:
            batch_t[sensor] = torch.stack(batch[sensor], dim=0)

    return batch_t.map(lambda v: v.to(device, non_blocking=True))
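The version above only shows how the cache is consumed: batch_t[sensor_name] is obtained once per sensor from cache.get(len(observations), sensor_name, sensor, device) and then filled row by row with copy_. A hypothetical stand-in with that call signature is sketched below; it is not the real ObservationBatchingCache, just an illustration of the interface this code relies on:

import torch

class _ToyObservationBatchingCache:
    """Illustrative stand-in: one pre-allocated tensor per (batch size, sensor),
    pinned when the target device is CUDA so the final .to(device,
    non_blocking=True) can overlap with CPU-side work."""

    def __init__(self):
        self._cache = {}

    def get(self, num_obs, sensor_name, sensor, device=None):
        key = (num_obs, sensor_name)
        if key not in self._cache:
            pin = device is not None and device.type == "cuda"
            self._cache[key] = torch.empty(
                (num_obs, *sensor.shape),
                dtype=sensor.dtype,
                pin_memory=pin,
            )
        return self._cache[key]

With such a cache, repeated calls to batch_obs reuse the same staging buffer instead of re-allocating and re-stacking tensors on every step, which is what the docstring means by faster stacking and CPU-GPU transfer.
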
def test_tensor_dict_index():
    dict_tree = dict(a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3))))
    tensor_dict = TensorDict.from_tree(dict_tree)

    with pytest.raises(KeyError):
        tensor_dict["b"][0] = dict(q=torch.randn(3))

    tmp = dict(c=dict(d=torch.randn(3)))
    tensor_dict["b"][0] = tmp
    assert torch.allclose(tensor_dict["b"]["c"]["d"][0], tmp["c"]["d"])
    assert not torch.allclose(tensor_dict["b"]["c"]["d"][1], tmp["c"]["d"])

    tensor_dict["b"]["c"]["x"] = torch.randn(5, 5)
    with pytest.raises(KeyError):
        tensor_dict["b"][1] = tmp

    tensor_dict["b"].set(1, tmp, strict=False)
    assert torch.allclose(tensor_dict["b"]["c"]["d"][1], tmp["c"]["d"])

    tmp = dict(c=dict(d=torch.randn(1, 3)))
    del tensor_dict["b"]["c"]["x"]
    tensor_dict["b"][2:3] = tmp
    assert torch.allclose(tensor_dict["b"]["c"]["d"][2:3], tmp["c"]["d"])
Example #7
def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
    cache: Optional[ObservationBatchingCache] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched
    observations.

    Args:
        observations: List of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.
        cache: An ObservationBatchingCache. This enables faster
            stacking of observations and CPU-GPU transfer as it
            maintains a correctly sized tensor for the batched
            observations that is pinned to CUDA memory.

    Returns:
        Transposed dict of torch.Tensor of observations.
    """
    batch_t: TensorDict = TensorDict()
    if cache is None:
        batch: DefaultDict[str, List] = defaultdict(list)

    obs = observations[0]
    # Order sensors by size, stack and move the largest first
    sensor_names = sorted(
        obs.keys(),
        key=lambda name: 1
        if isinstance(obs[name], numbers.Number) else np.prod(obs[name].shape),
        reverse=True,
    )

    for sensor_name in sensor_names:
        for i, obs in enumerate(observations):
            sensor = obs[sensor_name]
            if cache is None:
                batch[sensor_name].append(torch.as_tensor(sensor))
            else:
                if sensor_name not in batch_t:
                    batch_t[sensor_name] = cache.get(
                        len(observations),
                        sensor_name,
                        torch.as_tensor(sensor),
                        device,
                    )

                # Use isinstance(sensor, np.ndarray) here instead of
                # np.asarray as this is quicker for the more common
                # path of sensor being an np.ndarray.
                # np.asarray is ~3x slower than the isinstance check.
                if isinstance(sensor, np.ndarray):
                    batch_t[sensor_name][i] = sensor
                elif torch.is_tensor(sensor):
                    batch_t[sensor_name][i].copy_(sensor, non_blocking=True)
                # If the sensor wasn't a tensor, then it's some CPU side data
                # so use a numpy array
                else:
                    batch_t[sensor_name][i] = np.asarray(sensor)

        # With the batching cache we use pinned memory, so we can start
        # the move to the GPU asynchronously and continue stacking the
        # other sensors while the transfer runs.
        if cache is not None:
            # If we were using a numpy array to do indexing and copying,
            # convert back to torch tensor
            # We know that batch_t[sensor_name] is either an np.ndarray
            # or a torch.Tensor, so this is faster than torch.as_tensor
            if isinstance(batch_t[sensor_name], np.ndarray):
                batch_t[sensor_name] = torch.from_numpy(batch_t[sensor_name])

            batch_t[sensor_name] = batch_t[sensor_name].to(device,
                                                           non_blocking=True)

    if cache is None:
        for sensor in batch:
            batch_t[sensor] = torch.stack(batch[sensor], dim=0)

        batch_t.map_in_place(lambda v: v.to(device))

    return batch_t
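The cached branch above relies on a general PyTorch pattern: a host tensor in pinned (page-locked) memory can be sent to the GPU with non_blocking=True, so the transfer overlaps with the CPU work of stacking the remaining sensors. A minimal sketch of that pattern, independent of batch_obs:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Host-side staging buffer in pinned (page-locked) memory.
    staging = torch.empty(256, 64, 64, 3, pin_memory=True)
    staging.fill_(1.0)  # stand-in for writing sensor data into the buffer
    # From pinned memory, non_blocking=True returns immediately and the
    # copy runs in the background while the CPU keeps working.
    on_gpu = staging.to(device, non_blocking=True)
    # ... stack / copy other sensors here ...
    torch.cuda.synchronize()  # ensure the transfer finished before using on_gpu
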
def test_tensor_dict_to_tree():
    dict_tree = dict(a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3))))

    assert dict_tree == TensorDict.from_tree(dict_tree).to_tree()
class RolloutStorage:
    r"""Class for storing rollout information for RL trainers."""
    def __init__(
        self,
        numsteps,
        num_envs,
        observation_space,
        action_space,
        recurrent_hidden_state_size,
        num_recurrent_layers=1,
        is_double_buffered: bool = False,
    ):
        self.buffers = TensorDict()
        self.buffers["observations"] = TensorDict()

        for sensor in observation_space.spaces:
            self.buffers["observations"][sensor] = torch.from_numpy(
                np.zeros(
                    (
                        numsteps + 1,
                        num_envs,
                        *observation_space.spaces[sensor].shape,
                    ),
                    dtype=observation_space.spaces[sensor].dtype,
                ))

        self.buffers["recurrent_hidden_states"] = torch.zeros(
            numsteps + 1,
            num_envs,
            num_recurrent_layers,
            recurrent_hidden_state_size,
        )

        self.buffers["rewards"] = torch.zeros(numsteps + 1, num_envs, 1)
        self.buffers["value_preds"] = torch.zeros(numsteps + 1, num_envs, 1)
        self.buffers["returns"] = torch.zeros(numsteps + 1, num_envs, 1)

        self.buffers["action_log_probs"] = torch.zeros(numsteps + 1, num_envs,
                                                       1)
        if action_space.__class__.__name__ == "ActionSpace":
            action_shape = 1
        else:
            action_shape = action_space.shape[0]

        self.buffers["actions"] = torch.zeros(numsteps + 1, num_envs,
                                              action_shape)
        self.buffers["prev_actions"] = torch.zeros(numsteps + 1, num_envs,
                                                   action_shape)
        if action_space.__class__.__name__ == "ActionSpace":
            self.buffers["actions"] = self.buffers["actions"].long()
            self.buffers["prev_actions"] = self.buffers["prev_actions"].long()

        self.buffers["masks"] = torch.zeros(numsteps + 1,
                                            num_envs,
                                            1,
                                            dtype=torch.bool)

        self.is_double_buffered = is_double_buffered
        self._nbuffers = 2 if is_double_buffered else 1
        self._num_envs = num_envs

        assert (self._num_envs % self._nbuffers) == 0

        self.numsteps = numsteps
        self.current_rollout_step_idxs = [0 for _ in range(self._nbuffers)]

    @property
    def current_rollout_step_idx(self) -> int:
        assert all(s == self.current_rollout_step_idxs[0]
                   for s in self.current_rollout_step_idxs)
        return self.current_rollout_step_idxs[0]

    def to(self, device):
        self.buffers.map_in_place(lambda v: v.to(device))

    def insert(
        self,
        next_observations=None,
        next_recurrent_hidden_states=None,
        actions=None,
        action_log_probs=None,
        value_preds=None,
        rewards=None,
        next_masks=None,
        buffer_index: int = 0,
    ):
        if not self.is_double_buffered:
            assert buffer_index == 0

        next_step = dict(
            observations=next_observations,
            recurrent_hidden_states=next_recurrent_hidden_states,
            prev_actions=actions,
            masks=next_masks,
        )

        current_step = dict(
            actions=actions,
            action_log_probs=action_log_probs,
            value_preds=value_preds,
            rewards=rewards,
        )

        next_step = {k: v for k, v in next_step.items() if v is not None}
        current_step = {k: v for k, v in current_step.items() if v is not None}

        env_slice = slice(
            int(buffer_index * self._num_envs / self._nbuffers),
            int((buffer_index + 1) * self._num_envs / self._nbuffers),
        )

        if len(next_step) > 0:
            self.buffers.set(
                (self.current_rollout_step_idxs[buffer_index] + 1, env_slice),
                next_step,
                strict=False,
            )

        if len(current_step) > 0:
            self.buffers.set(
                (self.current_rollout_step_idxs[buffer_index], env_slice),
                current_step,
                strict=False,
            )

    def advance_rollout(self, buffer_index: int = 0):
        self.current_rollout_step_idxs[buffer_index] += 1

    def after_update(self):
        self.buffers[0] = self.buffers[self.current_rollout_step_idx]

        self.current_rollout_step_idxs = [
            0 for _ in self.current_rollout_step_idxs
        ]

    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            self.buffers["value_preds"][
                self.current_rollout_step_idx] = next_value
            gae = 0
            for step in reversed(range(self.current_rollout_step_idx)):
                delta = (self.buffers["rewards"][step] +
                         gamma * self.buffers["value_preds"][step + 1] *
                         self.buffers["masks"][step + 1] -
                         self.buffers["value_preds"][step])
                gae = (delta +
                       gamma * tau * gae * self.buffers["masks"][step + 1])
                self.buffers["returns"][step] = (
                    gae + self.buffers["value_preds"][step])
        else:
            self.buffers["returns"][self.current_rollout_step_idx] = next_value
            for step in reversed(range(self.current_rollout_step_idx)):
                self.buffers["returns"][step] = (
                    gamma * self.buffers["returns"][step + 1] *
                    self.buffers["masks"][step + 1] +
                    self.buffers["rewards"][step])

    def recurrent_generator(self, advantages, num_mini_batch) -> TensorDict:
        num_environments = advantages.size(1)
        assert num_environments >= num_mini_batch, (
            "Trainer requires the number of environments ({}) "
            "to be greater than or equal to the number of "
            "trainer mini batches ({}).".format(num_environments,
                                                num_mini_batch))
        if num_environments % num_mini_batch != 0:
            warnings.warn(
                "Number of environments ({}) is not a multiple of the"
                " number of mini batches ({}).  This results in mini batches"
                " of different sizes, which can harm training performance.".
                format(num_environments, num_mini_batch))
        for inds in torch.randperm(num_environments).chunk(num_mini_batch):
            batch = self.buffers[0:self.current_rollout_step_idx, inds]
            batch["advantages"] = advantages[0:self.current_rollout_step_idx,
                                             inds]
            batch["recurrent_hidden_states"] = batch[
                "recurrent_hidden_states"][0:1]

            yield batch.map(lambda v: v.flatten(0, 1))
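
A minimal usage sketch for RolloutStorage, assuming gym-style observation and action spaces (gym is an assumption here, chosen because it matches the .spaces / .shape / .dtype attributes the constructor reads) and the TensorDict from the earlier examples; all sizes are arbitrary:

import gym.spaces as spaces
import numpy as np
import torch

obs_space = spaces.Dict(
    {"rgb": spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)})
act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

rollouts = RolloutStorage(
    numsteps=4,
    num_envs=2,
    observation_space=obs_space,
    action_space=act_space,
    recurrent_hidden_state_size=32,
)

for _ in range(rollouts.numsteps):
    rollouts.insert(
        next_observations={"rgb": torch.zeros(2, 64, 64, 3, dtype=torch.uint8)},
        next_recurrent_hidden_states=torch.zeros(2, 1, 32),
        actions=torch.zeros(2, 2),
        action_log_probs=torch.zeros(2, 1),
        value_preds=torch.zeros(2, 1),
        rewards=torch.zeros(2, 1),
        next_masks=torch.ones(2, 1, dtype=torch.bool),
    )
    rollouts.advance_rollout()

rollouts.compute_returns(
    next_value=torch.zeros(2, 1), use_gae=True, gamma=0.99, tau=0.95)
advantages = rollouts.buffers["returns"] - rollouts.buffers["value_preds"]
for batch in rollouts.recurrent_generator(advantages, num_mini_batch=2):
    pass  # each batch is a flattened TensorDict ready for a policy update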