Example 1
def __init__(self,
             buffer_size: int,
             observation_space: spaces.Space,
             action_space: spaces.Space,
             device: Union[th.device, str] = 'cpu',
             n_envs: int = 1):
    super(BaseBuffer, self).__init__()
    self.buffer_size = buffer_size
    self.observation_space = observation_space
    self.action_space = action_space
    # shapes derived from the spaces, used to allocate the storage arrays
    self.obs_shape = get_obs_shape(observation_space)
    self.action_dim = get_action_dim(action_space)
    # insertion index and flag marking whether the buffer has wrapped around
    self.pos = 0
    self.full = False
    self.device = device
    self.n_envs = n_envs
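
For reference, the two helpers used above are available from stable_baselines3.common.preprocessing in recent SB3 releases. A minimal sketch of what they return, assuming gym and stable-baselines3 are installed (the env id is only an illustration):

import gym
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape

env = gym.make("CartPole-v1")
print(get_obs_shape(env.observation_space))  # (4,) for CartPole's Box observation
print(get_action_dim(env.action_space))      # 1 for a Discrete action space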
Example 2
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: Union[th.device, str] = "cpu",
    gae_lambda: float = 1,
    gamma: float = 0.99,
    num_trajectories: int = 20,  # TODO: expose as a parameter
):
    self.observation_space = observation_space
    self.action_space = action_space
    self.obs_shape = get_obs_shape(observation_space)
    self.action_dim = get_action_dim(action_space)
    self.full = False
    self.device = device
    # discount and GAE coefficients used when post-processing stored trajectories
    self.gae_lambda = gae_lambda
    self.gamma = gamma
    self.num_trajectories = num_trajectories
    self.traj_idx = 0
    self.live_agents: Dict[int, int] = {}  # env agent id -> buffer-unique id
    self.trajectories: Dict[int, TrajectoryBufferSamples] = {}  # buffer-unique id -> trajectory
    self.num_done_trajectories = 0
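
The gae_lambda and gamma fields suggest this buffer later computes advantages with Generalized Advantage Estimation. The buffer's own compute method is not shown here, so the following is only an illustrative sketch of the standard GAE recursion those two coefficients parameterize (the function name compute_gae and the zero bootstrap value at the end of the array are assumptions):

import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=1.0):
    # Standard GAE recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t),
    # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}.
    advantages = np.zeros(len(rewards))
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0  # assumed zero bootstrap
        next_non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    return advantages

rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
dones = np.array([0.0, 0.0, 1.0])
print(compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=1.0))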
Example 3
    def __init__(
        self,
        env: VecEnv,
        buffer_size: int,
        device: Union[th.device, str] = "cpu",
        replay_buffer: Optional[DictReplayBuffer] = None,
        max_episode_length: Optional[int] = None,
        n_sampled_goal: int = 4,
        goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
        online_sampling: bool = True,
        handle_timeout_termination: bool = True,
    ):

        super(HerReplayBuffer, self).__init__(
            buffer_size, env.observation_space, env.action_space, device, env.num_envs
        )

        # convert goal_selection_strategy into GoalSelectionStrategy if string
        if isinstance(goal_selection_strategy, str):
            self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[
                goal_selection_strategy.lower()]
        else:
            self.goal_selection_strategy = goal_selection_strategy

        # check if goal_selection_strategy is valid
        assert isinstance(
            self.goal_selection_strategy, GoalSelectionStrategy
        ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

        self.n_sampled_goal = n_sampled_goal
        # if we sample HER transitions online, use the custom replay buffer
        self.online_sampling = online_sampling
        # fraction of HER replays among all replays for online HER sampling
        self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))
        # maximum steps in episode
        self.max_episode_length = get_time_limit(env, max_episode_length)
        # storage for transitions of current episode for offline sampling
        # for online sampling, it replaces the "classic" replay buffer completely
        her_buffer_size = buffer_size if online_sampling else self.max_episode_length

        self.env = env
        self.buffer_size = her_buffer_size

        if online_sampling:
            replay_buffer = None
        self.replay_buffer = replay_buffer

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination

        # buffer with episodes
        # number of episodes which can be stored until buffer size is reached
        self.max_episode_stored = self.buffer_size // self.max_episode_length
        self.current_idx = 0
        # Counter to prevent overflow
        self.episode_steps = 0

        # Get shape of observation and goal (usually the same)
        self.obs_shape = get_obs_shape(
            self.env.observation_space.spaces["observation"])
        self.goal_shape = get_obs_shape(
            self.env.observation_space.spaces["achieved_goal"])

        # input dimensions for buffer initialization
        input_shape = {
            "observation": (self.env.num_envs,) + self.obs_shape,
            "achieved_goal": (self.env.num_envs,) + self.goal_shape,
            "desired_goal": (self.env.num_envs,) + self.goal_shape,
            "action": (self.action_dim,),
            "reward": (1,),
            "next_obs": (self.env.num_envs,) + self.obs_shape,
            "next_achieved_goal": (self.env.num_envs,) + self.goal_shape,
            "next_desired_goal": (self.env.num_envs,) + self.goal_shape,
            "done": (1,),
        }
        self._observation_keys = [
            "observation", "achieved_goal", "desired_goal"
        ]
        self._buffer = {
            key: np.zeros(
                (self.max_episode_stored, self.max_episode_length, *dim),
                dtype=np.float32)
            for key, dim in input_shape.items()
        }
        # Store info dicts as they can be used to compute the reward (e.g. continuity cost)
        self.info_buffer = [
            deque(maxlen=self.max_episode_length)
            for _ in range(self.max_episode_stored)
        ]
        # episode length storage, needed for episodes which have fewer steps than the maximum length
        self.episode_lengths = np.zeros(self.max_episode_stored,
                                        dtype=np.int64)
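
In stable-baselines3 1.x this constructor is usually not called directly; the buffer is handed to an off-policy algorithm through replay_buffer_class. A minimal usage sketch under that assumption (the env id "FetchReach-v1" is only illustrative, and the exact kwargs accepted vary between SB3 versions):

from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    "FetchReach-v1",
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
        online_sampling=True,
        max_episode_length=100,
    ),
)
model.learn(total_timesteps=1_000)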