def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched observations.

    Args:
        observations: list of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.

    Returns:
        transposed dict of torch.Tensor of observations.
    """
    batch: DefaultDict[str, List] = defaultdict(list)

    for obs in observations:
        for sensor in obs:
            batch[sensor].append(torch.as_tensor(obs[sensor]))

    batch_t: TensorDict = TensorDict()

    for sensor in batch:
        batch_t[sensor] = torch.stack(batch[sensor], dim=0)

    return batch_t.map(lambda v: v.to(device))
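A minimal usage sketch of batch_obs as defined above; the sensor names, shapes, and dtypes are illustrative only.

# Illustrative only: two fake per-env observations with an "rgb" and a "depth" sensor.
obs_a = {"rgb": np.zeros((64, 64, 3), dtype=np.uint8), "depth": np.zeros((64, 64, 1), dtype=np.float32)}
obs_b = {"rgb": np.ones((64, 64, 3), dtype=np.uint8), "depth": np.ones((64, 64, 1), dtype=np.float32)}

batch = batch_obs([obs_a, obs_b])
assert batch["rgb"].shape == (2, 64, 64, 3)  # torch.stack adds the leading batch dimension
assert batch["depth"].dtype == torch.float32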
def test_tensor_dict_constructor():
    dict_tree = dict(
        a=torch.randn(2, 2), b=dict(c=dict(d=np.random.randn(3, 3)))
    )
    tensor_dict = TensorDict.from_tree(dict_tree)

    assert torch.is_tensor(tensor_dict["a"])
    assert isinstance(tensor_dict["b"], TensorDict)
    assert isinstance(tensor_dict["b"]["c"], TensorDict)
    assert torch.is_tensor(tensor_dict["b"]["c"]["d"])
def test_tensor_dict_map():
    dict_tree = dict(a=dict(b=[0]))
    tensor_dict = TensorDict.from_tree(dict_tree)

    res = tensor_dict.map(lambda x: x + 1)
    assert (res["a"]["b"] == 1).all()

    tensor_dict.map_in_place(lambda x: x + 1)

    assert res == tensor_dict
def test_tensor_dict_str_index():
    dict_tree = dict(
        a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3)))
    )
    tensor_dict = TensorDict.from_tree(dict_tree)

    x = torch.randn(5, 5)
    tensor_dict["a"] = x
    assert (tensor_dict["a"] == x).all()

    with pytest.raises(KeyError):
        _ = tensor_dict["c"]
def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
    cache: Optional[ObservationBatchingCache] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched observations.

    Args:
        observations: list of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.
        cache: An ObservationBatchingCache. This enables faster
            stacking of observations and cpu-gpu transfer as it maintains
            a correctly sized tensor for the batched observations that is
            pinned to cuda memory.

    Returns:
        transposed dict of torch.Tensor of observations.
    """
    batch_t: TensorDict = TensorDict()
    if cache is None:
        batch: DefaultDict[str, List] = defaultdict(list)

    for i, obs in enumerate(observations):
        for sensor_name, sensor in obs.items():
            sensor = torch.as_tensor(sensor)
            if cache is None:
                batch[sensor_name].append(sensor)
            else:
                if sensor_name not in batch_t:
                    batch_t[sensor_name] = cache.get(
                        len(observations), sensor_name, sensor, device
                    )

                batch_t[sensor_name][i].copy_(sensor)

    if cache is None:
        for sensor in batch:
            batch_t[sensor] = torch.stack(batch[sensor], dim=0)

    return batch_t.map(lambda v: v.to(device, non_blocking=True))
def test_tensor_dict_index():
    dict_tree = dict(
        a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3)))
    )
    tensor_dict = TensorDict.from_tree(dict_tree)

    with pytest.raises(KeyError):
        tensor_dict["b"][0] = dict(q=torch.randn(3))

    tmp = dict(c=dict(d=torch.randn(3)))
    tensor_dict["b"][0] = tmp

    assert torch.allclose(tensor_dict["b"]["c"]["d"][0], tmp["c"]["d"])
    assert not torch.allclose(tensor_dict["b"]["c"]["d"][1], tmp["c"]["d"])

    tensor_dict["b"]["c"]["x"] = torch.randn(5, 5)

    with pytest.raises(KeyError):
        tensor_dict["b"][1] = tmp

    tensor_dict["b"].set(1, tmp, strict=False)
    assert torch.allclose(tensor_dict["b"]["c"]["d"][1], tmp["c"]["d"])

    tmp = dict(c=dict(d=torch.randn(1, 3)))
    del tensor_dict["b"]["c"]["x"]
    tensor_dict["b"][2:3] = tmp
    assert torch.allclose(tensor_dict["b"]["c"]["d"][2:3], tmp["c"]["d"])
def batch_obs(
    observations: List[DictTree],
    device: Optional[torch.device] = None,
    cache: Optional[ObservationBatchingCache] = None,
) -> TensorDict:
    r"""Transpose a batch of observation dicts to a dict of batched observations.

    Args:
        observations: list of dicts of observations.
        device: The torch.device to put the resulting tensors on.
            Will not move the tensors if None.
        cache: An ObservationBatchingCache. This enables faster
            stacking of observations and cpu-gpu transfer as it maintains
            a correctly sized tensor for the batched observations that is
            pinned to cuda memory.

    Returns:
        transposed dict of torch.Tensor of observations.
    """
    batch_t: TensorDict = TensorDict()
    if cache is None:
        batch: DefaultDict[str, List] = defaultdict(list)

    obs = observations[0]
    # Order sensors by size, stack and move the largest first
    sensor_names = sorted(
        obs.keys(),
        key=lambda name: 1
        if isinstance(obs[name], numbers.Number)
        else np.prod(obs[name].shape),
        reverse=True,
    )

    for sensor_name in sensor_names:
        for i, obs in enumerate(observations):
            sensor = obs[sensor_name]
            if cache is None:
                batch[sensor_name].append(torch.as_tensor(sensor))
            else:
                if sensor_name not in batch_t:
                    batch_t[sensor_name] = cache.get(
                        len(observations),
                        sensor_name,
                        torch.as_tensor(sensor),
                        device,
                    )

                # Use isinstance(sensor, np.ndarray) here instead of
                # np.asarray as this is quicker for the more common
                # path of sensor being an np.ndarray
                # np.asarray is ~3x slower than checking
                if isinstance(sensor, np.ndarray):
                    batch_t[sensor_name][i] = sensor
                elif torch.is_tensor(sensor):
                    batch_t[sensor_name][i].copy_(sensor, non_blocking=True)
                # If the sensor wasn't a tensor, then it's some CPU side data
                # so use a numpy array
                else:
                    batch_t[sensor_name][i] = np.asarray(sensor)

        # With the batching cache, we use pinned mem
        # so we can start the move to the GPU async
        # and continue stacking other things with it
        if cache is not None:
            # If we were using a numpy array to do indexing and copying,
            # convert back to torch tensor
            # We know that batch_t[sensor_name] is either an np.ndarray
            # or a torch.Tensor, so this is faster than torch.as_tensor
            if isinstance(batch_t[sensor_name], np.ndarray):
                batch_t[sensor_name] = torch.from_numpy(batch_t[sensor_name])

            batch_t[sensor_name] = batch_t[sensor_name].to(
                device, non_blocking=True
            )

    if cache is None:
        for sensor in batch:
            batch_t[sensor] = torch.stack(batch[sensor], dim=0)

        batch_t.map_in_place(lambda v: v.to(device))

    return batch_t
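A hedged sketch of the cached path. It assumes ObservationBatchingCache can be constructed with no arguments, since only its get(...) method is exercised by batch_obs above; the observations themselves are illustrative stand-ins.

# Assumption: ObservationBatchingCache() takes no constructor arguments.
cache = ObservationBatchingCache()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Illustrative observations: one dict per environment.
observations = [
    {"rgb": np.zeros((64, 64, 3), dtype=np.uint8), "gps": np.zeros(2, dtype=np.float32)}
    for _ in range(4)
]

# Reusing the same cache across calls lets batch_obs write each sensor into a
# persistent, correctly sized (and, on CUDA, pinned) buffer and overlap the
# host-to-device copy of one sensor with the stacking of the remaining ones.
batch = batch_obs(observations, device=device, cache=cache)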
def test_tensor_dict_to_tree():
    dict_tree = dict(
        a=torch.randn(2, 2), b=dict(c=dict(d=torch.randn(3, 3)))
    )

    assert dict_tree == TensorDict.from_tree(dict_tree).to_tree()
class RolloutStorage:
    r"""Class for storing rollout information for RL trainers."""

    def __init__(
        self,
        numsteps,
        num_envs,
        observation_space,
        action_space,
        recurrent_hidden_state_size,
        num_recurrent_layers=1,
        is_double_buffered: bool = False,
    ):
        self.buffers = TensorDict()
        self.buffers["observations"] = TensorDict()

        for sensor in observation_space.spaces:
            self.buffers["observations"][sensor] = torch.from_numpy(
                np.zeros(
                    (
                        numsteps + 1,
                        num_envs,
                        *observation_space.spaces[sensor].shape,
                    ),
                    dtype=observation_space.spaces[sensor].dtype,
                )
            )

        self.buffers["recurrent_hidden_states"] = torch.zeros(
            numsteps + 1,
            num_envs,
            num_recurrent_layers,
            recurrent_hidden_state_size,
        )

        self.buffers["rewards"] = torch.zeros(numsteps + 1, num_envs, 1)
        self.buffers["value_preds"] = torch.zeros(numsteps + 1, num_envs, 1)
        self.buffers["returns"] = torch.zeros(numsteps + 1, num_envs, 1)

        self.buffers["action_log_probs"] = torch.zeros(
            numsteps + 1, num_envs, 1
        )

        if action_space.__class__.__name__ == "ActionSpace":
            action_shape = 1
        else:
            action_shape = action_space.shape[0]

        self.buffers["actions"] = torch.zeros(
            numsteps + 1, num_envs, action_shape
        )
        self.buffers["prev_actions"] = torch.zeros(
            numsteps + 1, num_envs, action_shape
        )
        if action_space.__class__.__name__ == "ActionSpace":
            self.buffers["actions"] = self.buffers["actions"].long()
            self.buffers["prev_actions"] = self.buffers["prev_actions"].long()

        self.buffers["masks"] = torch.zeros(
            numsteps + 1, num_envs, 1, dtype=torch.bool
        )

        self.is_double_buffered = is_double_buffered
        self._nbuffers = 2 if is_double_buffered else 1
        self._num_envs = num_envs

        assert (self._num_envs % self._nbuffers) == 0

        self.numsteps = numsteps
        self.current_rollout_step_idxs = [0 for _ in range(self._nbuffers)]

    @property
    def current_rollout_step_idx(self) -> int:
        assert all(
            s == self.current_rollout_step_idxs[0]
            for s in self.current_rollout_step_idxs
        )
        return self.current_rollout_step_idxs[0]

    def to(self, device):
        self.buffers.map_in_place(lambda v: v.to(device))

    def insert(
        self,
        next_observations=None,
        next_recurrent_hidden_states=None,
        actions=None,
        action_log_probs=None,
        value_preds=None,
        rewards=None,
        next_masks=None,
        buffer_index: int = 0,
    ):
        if not self.is_double_buffered:
            assert buffer_index == 0

        next_step = dict(
            observations=next_observations,
            recurrent_hidden_states=next_recurrent_hidden_states,
            prev_actions=actions,
            masks=next_masks,
        )

        current_step = dict(
            actions=actions,
            action_log_probs=action_log_probs,
            value_preds=value_preds,
            rewards=rewards,
        )

        next_step = {k: v for k, v in next_step.items() if v is not None}
        current_step = {
            k: v for k, v in current_step.items() if v is not None
        }

        env_slice = slice(
            int(buffer_index * self._num_envs / self._nbuffers),
            int((buffer_index + 1) * self._num_envs / self._nbuffers),
        )

        if len(next_step) > 0:
            self.buffers.set(
                (self.current_rollout_step_idxs[buffer_index] + 1, env_slice),
                next_step,
                strict=False,
            )

        if len(current_step) > 0:
            self.buffers.set(
                (self.current_rollout_step_idxs[buffer_index], env_slice),
                current_step,
                strict=False,
            )

    def advance_rollout(self, buffer_index: int = 0):
        self.current_rollout_step_idxs[buffer_index] += 1

    def after_update(self):
        self.buffers[0] = self.buffers[self.current_rollout_step_idx]

        self.current_rollout_step_idxs = [
            0 for _ in self.current_rollout_step_idxs
        ]

    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            self.buffers["value_preds"][
                self.current_rollout_step_idx
            ] = next_value

            gae = 0
            for step in reversed(range(self.current_rollout_step_idx)):
                delta = (
                    self.buffers["rewards"][step]
                    + gamma
                    * self.buffers["value_preds"][step + 1]
                    * self.buffers["masks"][step + 1]
                    - self.buffers["value_preds"][step]
                )
                gae = (
                    delta
                    + gamma * tau * gae * self.buffers["masks"][step + 1]
                )
                self.buffers["returns"][step] = (
                    gae + self.buffers["value_preds"][step]
                )
        else:
            self.buffers["returns"][
                self.current_rollout_step_idx
            ] = next_value
            for step in reversed(range(self.current_rollout_step_idx)):
                self.buffers["returns"][step] = (
                    gamma
                    * self.buffers["returns"][step + 1]
                    * self.buffers["masks"][step + 1]
                    + self.buffers["rewards"][step]
                )

    def recurrent_generator(self, advantages, num_mini_batch) -> TensorDict:
        num_environments = advantages.size(1)
        assert num_environments >= num_mini_batch, (
            "Trainer requires the number of environments ({}) "
            "to be greater than or equal to the number of "
            "trainer mini batches ({}).".format(
                num_environments, num_mini_batch
            )
        )
        if num_environments % num_mini_batch != 0:
            warnings.warn(
                "Number of environments ({}) is not a multiple of the"
                " number of mini batches ({}). This results in mini batches"
                " of different sizes, which can harm training performance.".format(
                    num_environments, num_mini_batch
                )
            )

        for inds in torch.randperm(num_environments).chunk(num_mini_batch):
            batch = self.buffers[0 : self.current_rollout_step_idx, inds]
            batch["advantages"] = advantages[
                0 : self.current_rollout_step_idx, inds
            ]
            batch["recurrent_hidden_states"] = batch[
                "recurrent_hidden_states"
            ][0:1]

            yield batch.map(lambda v: v.flatten(0, 1))
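A hedged sketch of how the generator above might be consumed in a PPO-style update loop; `rollouts`, `advantages`, and the mini-batch count are illustrative stand-ins rather than names from the original code.

# Illustrative only: `rollouts` is a RolloutStorage that has been filled for a
# complete rollout, and `advantages` has shape (numsteps, num_envs, 1).
for batch in rollouts.recurrent_generator(advantages, num_mini_batch=2):
    # Every leaf tensor has been flattened over the (time, env) dimensions
    # into a single batch dimension by batch.map(lambda v: v.flatten(0, 1)),
    # except "recurrent_hidden_states", which was first sliced to the initial step.
    observations = batch["observations"]
    actions = batch["actions"]
    returns = batch["returns"]
    advantages_batch = batch["advantages"]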