Example #1
  def __init__(self, size, alpha):
    """Create Prioritized Replay buffer.
    Parameters
    ----------
    size: int
      Max number of transitions to store in the buffer. When the buffer
      overflows the old memories are dropped.
    alpha: float
      how much prioritization is used
      (0 - no prioritization, 1 - full prioritization)
    See Also
    --------
    ReplayBuffer.__init__
    """
    super(PrioritizedReplayBuffer, self).__init__(size)
    assert alpha >= 0
    self._alpha = alpha

    it_capacity = 1
    while it_capacity < size:
      it_capacity *= 2

    self._it_sum = SumSegmentTree(it_capacity)
    self._it_min = MinSegmentTree(it_capacity)
    self._max_priority = 1.0
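
The constructor above rounds `it_capacity` up to the next power of two because the segment trees are laid out as full binary trees over their leaves. The sketch below is illustrative only, not the `SumSegmentTree` the examples actually import: it assumes the usual baselines-style interface (`__setitem__`, `sum`, `find_prefixsum_idx`) and keeps `sum` as a plain linear pass for brevity, showing just the operations the buffer relies on.

class SumSegmentTreeSketch:
  """Illustrative sum segment tree over `capacity` leaves (capacity must be
  a power of two). Not the SumSegmentTree used by the examples."""

  def __init__(self, capacity):
    assert capacity > 0 and capacity & (capacity - 1) == 0
    self._capacity = capacity
    # Node 1 is the root; leaves live at indices [capacity, 2 * capacity).
    self._value = [0.0] * (2 * capacity)

  def __setitem__(self, idx, val):
    # Write the leaf, then refresh the sums on the path up to the root.
    idx += self._capacity
    self._value[idx] = val
    idx //= 2
    while idx >= 1:
      self._value[idx] = self._value[2 * idx] + self._value[2 * idx + 1]
      idx //= 2

  def __getitem__(self, idx):
    return self._value[self._capacity + idx]

  def sum(self, start=0, end=None):
    # Sum of the leaves in [start, end]; a linear pass keeps the sketch
    # short, while the real tree answers this in O(log n).
    if end is None:
      end = self._capacity - 1
    return sum(self._value[self._capacity + start:self._capacity + end + 1])

  def find_prefixsum_idx(self, prefixsum):
    # Walk from the root towards the leaf whose cumulative priority first
    # exceeds `prefixsum`; this is what makes proportional sampling O(log n).
    idx = 1
    while idx < self._capacity:
      if self._value[2 * idx] > prefixsum:
        idx = 2 * idx
      else:
        prefixsum -= self._value[2 * idx]
        idx = 2 * idx + 1
    return idx - self._capacity


# With priorities [1, 3, 2, 2], any prefix mass in [1, 4) lands on index 1.
tree = SumSegmentTreeSketch(4)
for i, p in enumerate([1.0, 3.0, 2.0, 2.0]):
  tree[i] = p
assert tree.find_prefixsum_idx(2.5) == 1
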
Example #2
import random

import numpy as np

# The segment trees and the base ReplayBuffer are assumed to come from an
# OpenAI-baselines-style layout; adjust these paths to the project's own.
from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
from baselines.deepq.replay_buffer import ReplayBuffer


class PrioritizedReplayBuffer(ReplayBuffer):
  def __init__(self, size, alpha):
    """Create Prioritized Replay buffer.
    Parameters
    ----------
    size: int
      Max number of transitions to store in the buffer. When the buffer
      overflows the old memories are dropped.
    alpha: float
      how much prioritization is used
      (0 - no prioritization, 1 - full prioritization)
    See Also
    --------
    ReplayBuffer.__init__
    """
    super(PrioritizedReplayBuffer, self).__init__(size)
    assert alpha >= 0
    self._alpha = alpha

    it_capacity = 1
    while it_capacity < size:
      it_capacity *= 2

    self._it_sum = SumSegmentTree(it_capacity)
    self._it_min = MinSegmentTree(it_capacity)
    self._max_priority = 1.0

  def add(self, obs_t, action, reward, obs_tp1, done):
    """See ReplayBuffer.store_effect"""
    idx = self._next_idx
    super(PrioritizedReplayBuffer, self).add(obs_t, action, reward, obs_tp1, done)
    # New transitions are stored with the maximum priority seen so far, so
    # every experience is sampled at least once before its priority is
    # refined from its TD error via update_priorities.
    self._it_sum[idx] = self._max_priority ** self._alpha
    self._it_min[idx] = self._max_priority ** self._alpha


  def _sample_proportional(self, batch_size):
    """Pick batch_size indexes with probability proportional to priority,
    using one uniform draw per equal slice of the total priority mass."""
    res = []
    p_total = self._it_sum.sum(0, len(self._storage) - 1)
    every_range_len = p_total / batch_size
    for i in range(batch_size):
      # Draw a mass uniformly from the i-th slice and map it to the
      # transition whose cumulative priority covers that mass.
      mass = random.random() * every_range_len + i * every_range_len
      idx = self._it_sum.find_prefixsum_idx(mass)
      res.append(idx)
    return res

  def sample(self, batch_size, beta):
    """Sample a batch of experiences.
    compared to ReplayBuffer.sample
    it also returns importance weights and idxes
    of sampled experiences.
    Parameters
    ----------
    batch_size: int
      How many transitions to sample.
    beta: float
      To what degree to use importance weights
      (0 - no corrections, 1 - full correction)
    Returns
    -------
    obs_batch: np.array
      batch of observations
    act_batch: np.array
      batch of actions executed given obs_batch
    rew_batch: np.array
      rewards received as results of executing act_batch
    next_obs_batch: np.array
      next set of observations seen after executing act_batch
    done_mask: np.array
      done_mask[i] = 1 if executing act_batch[i] resulted in
      the end of an episode and 0 otherwise.
    weights: np.array
      Array of shape (batch_size,) and dtype np.float32
      denoting importance weight of each sampled transition
    idxes: [int]
      List of length batch_size containing the indexes in the buffer
      of the sampled experiences.
    """
    assert beta > 0

    idxes = self._sample_proportional(batch_size)

    weights = []
    # Normalize by the largest possible weight (that of the lowest-priority
    # transition) so the importance weights only ever scale updates down.
    p_min = self._it_min.min() / self._it_sum.sum()
    max_weight = (p_min * len(self._storage)) ** (-beta)

    for idx in idxes:
      p_sample = self._it_sum[idx] / self._it_sum.sum()
      weight = (p_sample * len(self._storage)) ** (-beta)
      weights.append(weight / max_weight)
    weights = np.array(weights, dtype=np.float32)
    encoded_sample = self._encode_sample(idxes)
    return (encoded_sample, weights, idxes)

  def update_priorities(self, idxes, priorities):
    """Update priorities of sampled transitions.
    sets priority of transition at index idxes[i] in buffer
    to priorities[i].
    Parameters
    ----------
    idxes: [int]
      List of idxes of sampled transitions
    priorities: [float]
      List of updated priorities corresponding to
      transitions at the sampled idxes denoted by
      variable `idxes`.
    """
    assert len(idxes) == len(priorities)
    for idx, priority in zip(idxes, priorities):
      assert priority > 0
      assert 0 <= idx < len(self._storage)
      # Priorities are stored raised to the power alpha, consistent with
      # how new transitions are inserted in add().
      self._it_sum[idx] = priority ** self._alpha
      self._it_min[idx] = priority ** self._alpha

      self._max_priority = max(self._max_priority, priority)
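
A hedged usage sketch of Example #2: the buffer size, alpha and beta values, dummy transitions, and the |TD error| + epsilon priorities are illustrative choices rather than part of the example, and ReplayBuffer is assumed to supply the usual add/_encode_sample machinery.

import numpy as np

# Illustrative only: plug the example's PrioritizedReplayBuffer into a
# typical training-loop shape.
buffer = PrioritizedReplayBuffer(size=10000, alpha=0.6)

# Store a few dummy transitions (obs_t, action, reward, obs_tp1, done).
for t in range(100):
  buffer.add(np.random.randn(4), 0, float(t % 3), np.random.randn(4), False)

# Sample a prioritized batch; in practice beta is annealed towards 1.0 so
# the importance weights eventually correct the sampling bias in full.
encoded, weights, idxes = buffer.sample(batch_size=32, beta=0.4)
obs_b, act_b, rew_b, next_obs_b, done_b = encoded

# After computing TD errors for the batch (random placeholders here),
# refresh the priorities of the transitions that were just sampled.
td_errors = np.random.rand(32)
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)
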