def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
    """
    Build a ``RolloutBufferSamples`` from the buffer entries selected by ``batch_inds``.

    When the observation space is a ``JsonGraph`` the selected observations are
    handed over as-is; in every other case they go through ``self.to_torch``
    like the remaining fields.

    :param batch_inds: indices used to index each of the buffer's arrays
    :param env: accepted for interface compatibility; not used by this method
    :return: samples holding, in order: observations, actions, values,
        log-probs, advantages and returns (the last four flattened)
    """
    selected_obs = self.observations[batch_inds]
    other_fields = (
        self.actions[batch_inds],
        self.values[batch_inds].flatten(),
        self.log_probs[batch_inds].flatten(),
        self.advantages[batch_inds].flatten(),
        self.returns[batch_inds].flatten(),
    )

    # Common case first: everything, observations included, is converted.
    if not isinstance(self.observation_space, JsonGraph):
        everything = (selected_obs,) + other_fields
        return RolloutBufferSamples(*[self.to_torch(field) for field in everything])

    # JsonGraph observations skip the to_torch conversion.
    converted = [self.to_torch(field) for field in other_fields]
    return RolloutBufferSamples(selected_obs, *converted)
def format_trajectories(trajectories: List[TrajectoryBufferSamples]) -> RolloutBufferSamples:
    """
    Merge per-trajectory samples into one shuffled ``RolloutBufferSamples`` batch.

    Each trajectory's ``context`` is repeated once per step of that trajectory
    and appended to the observations along the last axis. Every per-step field
    is concatenated across trajectories, squeezed, reordered with one shared
    random permutation (drawn from NumPy's global RNG) and converted to a
    torch tensor.

    :param trajectories: trajectories to merge; each must expose ``observations``,
        ``context``, ``actions``, ``returns``, ``values``, ``log_probs`` and
        ``advantages``. ``context_error`` is optional.
    :return: samples built from, in order: observations (with context appended),
        actions, values, log-probs, advantages, returns and context error.
    :raises ValueError: if ``trajectories`` is empty.
    """
    if not trajectories:
        # np.concatenate would raise ValueError anyway; fail with a clearer message.
        raise ValueError("format_trajectories() needs at least one trajectory")

    def _gather(field: str) -> np.ndarray:
        """Concatenate ``field`` over all trajectories and drop size-1 axes."""
        return np.concatenate([getattr(t, field) for t in trajectories]).squeeze()

    observations = _gather("observations")
    # Repeat each trajectory's context for every one of its steps so it can be
    # appended to the matching observation rows.
    contexts = np.concatenate(
        [
            np.broadcast_to(t.context, (t.observations.shape[0],) + t.context.shape)
            for t in trajectories
        ]
    )
    observations = np.concatenate([observations, contexts], axis=-1).squeeze()

    actions = _gather("actions")
    returns = _gather("returns")
    values = _gather("values")
    log_probs = _gather("log_probs")
    advantages = _gather("advantages")

    try:
        context_error = _gather("context_error")
    except (AttributeError, TypeError, ValueError):
        # Best effort: trajectories without a usable context_error get zeros.
        # Only data-related failures are caught, so KeyboardInterrupt and
        # SystemExit are no longer swallowed as with a bare ``except``.
        context_error = np.zeros(values.shape)

    # One permutation shared by all fields keeps the rows aligned.
    indices = np.arange(actions.shape[0])
    np.random.shuffle(indices)

    fields = (observations, actions, values, log_probs, advantages, returns, context_error)
    # Convert each array exactly once: wrapping an existing tensor in th.tensor
    # again only makes a redundant copy and triggers a torch UserWarning.
    return RolloutBufferSamples(*(th.tensor(field[indices]) for field in fields))
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
    """
    Return the buffer entries selected by ``batch_inds`` as a ``RolloutBufferSamples``.

    Every field is indexed with ``batch_inds`` and converted with
    ``self.to_torch``; values, log-probs, advantages and returns are flattened
    before conversion.

    :param batch_inds: indices used to index each of the buffer's arrays
    :param env: accepted for interface compatibility; not used by this method
    :return: samples holding, in order: observations, actions, values,
        log-probs, advantages and returns
    """
    obs_batch = self.observations[batch_inds]
    action_batch = self.actions[batch_inds]
    value_batch = self.values[batch_inds].flatten()
    log_prob_batch = self.log_probs[batch_inds].flatten()
    advantage_batch = self.advantages[batch_inds].flatten()
    return_batch = self.returns[batch_inds].flatten()

    ordered = (
        obs_batch,
        action_batch,
        value_batch,
        log_prob_batch,
        advantage_batch,
        return_batch,
    )
    tensors = [self.to_torch(array) for array in ordered]
    return RolloutBufferSamples(*tensors)