class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
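            # find_prefixsum_idx returns the first index at which the running sum of
            # priorities exceeds `mass`, so each transition is drawn with probability
            # proportional to its (alpha-scaled) priority.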
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample, this also returns importance weights
        and the indexes of the sampled experiences.


        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
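
A minimal usage sketch for the class above. It assumes the baselines-style dependencies
(a ReplayBuffer base class with an add(obs_t, action, reward, obs_tp1, done) signature,
SumSegmentTree, MinSegmentTree) are importable, and uses random stand-in data in place of
real transitions and TD errors:

import numpy as np

buffer = PrioritizedReplayBuffer(size=2 ** 16, alpha=0.6)
rng = np.random.default_rng(0)
for _ in range(100):
    obs, next_obs = rng.normal(size=4), rng.normal(size=4)
    buffer.add(obs, int(rng.integers(2)), float(rng.normal()), next_obs, False)

# Sample with importance-sampling correction, then feed |TD error| back as the new priority.
obs_b, act_b, rew_b, next_obs_b, done_b, weights, idxes = buffer.sample(32, beta=0.4)
td_errors = rng.normal(size=32)  # stand-in for the learner's TD errors
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)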
Example #2
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample, this also returns importance weights
        and the indexes of the sampled experiences.


        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
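
A small self-contained sketch of the importance-weight computation in sample() above,
with made-up priorities and a plain array standing in for the segment trees:

import numpy as np

N, beta = 4, 0.4
p = np.array([1.0, 1.0, 2.0, 4.0])     # priorities with alpha already applied
P = p / p.sum()                        # sampling probabilities P(i)
w = (N * P) ** (-beta)                 # raw importance weights (N * P(i))^-beta
max_weight = (N * P.min()) ** (-beta)  # weight of the least likely transition (p_min)
print(w / max_weight)                  # normalized weights, all <= 1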
Example #3
class ReplayMemory:
    def __init__(self, replay_size, alpha=0.6):
        self.replay_size = replay_size
        self.cnt = 0
        self._alpha = alpha
        it_capacity = 1
        while it_capacity < replay_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._storage = []
        self._maxsize = replay_size
        self._next_idx = 0

    def add(self, data):
        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        # Record the slot that was just written *before* advancing the cursor,
        # so the new transition (not its neighbour) receives the max priority.
        idx = self._next_idx
        self._next_idx = (self._next_idx + 1) % self._maxsize
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha


    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=0.4):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        weights /= np.sum(weights)
        ret = []
        for i in range(batch_size):
            ret.append(self._storage[idxes[i]])
        return (ret, idxes, weights)

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
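
A minimal round-trip sketch for the ReplayMemory above, assuming SumSegmentTree and
MinSegmentTree are importable and treating each stored transition as an opaque tuple:

import numpy as np

memory = ReplayMemory(replay_size=1024, alpha=0.6)
for t in range(200):
    memory.add((np.zeros(4), t % 3, float(t), np.zeros(4), False))

batch, idxes, weights = memory.sample(batch_size=16, beta=0.4)
new_priorities = np.abs(np.random.randn(16)) + 1e-6  # stand-in for |TD errors|
memory.update_priorities(idxes, new_priorities)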
Example #4
class PrioritizedReplayBuffer(SimpleReplayBuffer):
    def __init__(self, MAX_LEN, alpha: float = 0.6):
        """
        Create Prioritized Replay buffer.

        Parameters
        ----------
        MAX_LEN: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        SimpleReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(MAX_LEN)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < MAX_LEN:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, experience: Experience):
        """See SimpleReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(experience)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _sample_proportional(self, batch_size: int) -> List[int]:
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
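        # Stratified sampling: the total priority mass is split into batch_size equal
        # slices and one index is drawn from each slice, which lowers the variance of
        # the batch compared to independent proportional draws.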
        for i in range(batch_size):
            mass = random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(
        self,
        batch_size: int,
        beta: float = 0.4
    ) -> Tuple[List[Experience], np.ndarray, List[int]]:  # type: ignore
        """Sample a batch of experiences.

        Compared to SimpleReplayBuffer.sample, this also returns importance
        weights and the indexes of the sampled experiences.


        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        experiences: List[Experience]
            batch of experiences
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: List[int]
            List of length batch_size with the indexes in the buffer
            of the sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return (encoded_sample, weights, idxes)

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
Example #5
class PrioritizedReplayMemory(object):
    def __init__(self,
                 capacity=100000,
                 priority_fraction=0.0,
                 discount_gamma_game_reward=1.0,
                 discount_gamma_graph_reward=1.0,
                 discount_gamma_count_reward=1.0,
                 accumulate_reward_from_final=False):
        # prioritized replay memory
        self._storage = []
        self.capacity = capacity
        self._next_idx = 0

        assert priority_fraction >= 0
        self._alpha = priority_fraction

        it_capacity = 1
        while it_capacity < capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self.discount_gamma_game_reward = discount_gamma_game_reward
        self.discount_gamma_graph_reward = discount_gamma_graph_reward
        self.discount_gamma_count_reward = discount_gamma_count_reward
        self.accumulate_reward_from_final = accumulate_reward_from_final

    def __len__(self):
        return len(self._storage)

    @property
    def storage(self):
        """[(np.ndarray, float, float, np.ndarray, bool)]: content of the replay buffer"""
        return self._storage

    @property
    def buffer_size(self):
        """float: Max capacity of the buffer"""
        return self.capacity

    def can_sample(self, n_samples):
        """
        Check if n_samples samples can be sampled
        from the buffer.
        :param n_samples: (int)
        :return: (bool)
        """
        return len(self) >= n_samples

    def is_full(self):
        """
        Check whether the replay buffer is full or not.
        :return: (bool)
        """
        return len(self) == self.buffer_size

    def add(self, *args):
        """
        add a new transition to the buffer
        """
        idx = self._next_idx
        data = Transition(*args)

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self.capacity
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def get_next_final_pos(self, which_memory, head):
        # Scan forward from `head` for the position of the next terminal transition.
        i = head
        while True:
            if i >= len(self._storage):
                return None
            if self._storage[i].is_final:
                return i
            i += 1

    def _get_single_transition(self, idx, n):
        assert n > 0
        head = idx
        # if n is 1, then head can't be is_final
        if n == 1:
            if self._storage[head].is_final:
                return None
        #  if n > 1, then all except tail can't be is_final
        else:
            if np.any([item.is_final
                       for item in self._storage[head:head + n]]):
                return None

        next_final = self.get_next_final_pos(self._storage, head)
        if next_final is None:
            return None

        # all good
        obs = self._storage[head].observation_list
        candidate = self._storage[head].action_candidate_list
        chosen_indices = self._storage[head].chosen_indices
        graph_triplets = self._storage[head].graph_triplets

        next_obs = self._storage[head + n].observation_list
        next_candidate = self._storage[head + n].action_candidate_list
        next_graph_triplets = self._storage[head + n].graph_triplets

        tmp = next_final - head + 1 if self.accumulate_reward_from_final else n + 1

        rewards_up_to_next_final = [
            self.discount_gamma_game_reward**i * self._storage[head + i].reward
            for i in range(tmp)
        ]
        reward = torch.sum(torch.stack(rewards_up_to_next_final))

        graph_rewards_up_to_next_final = [
            self.discount_gamma_graph_reward**i *
            self._storage[head + i].graph_reward for i in range(tmp)
        ]
        graph_reward = torch.sum(torch.stack(graph_rewards_up_to_next_final))

        count_rewards_up_to_next_final = [
            self.discount_gamma_count_reward**i *
            self._storage[head + i].count_reward for i in range(tmp)
        ]
        count_reward = torch.sum(torch.stack(count_rewards_up_to_next_final))

        return (obs, candidate, chosen_indices, graph_triplets,
                reward + graph_reward + count_reward, next_obs, next_candidate,
                next_graph_triplets)

    def _encode_sample(self, idxes, ns):
        actual_indices, actual_ns = [], []
        obs, candidate, chosen_indices, graph_triplets, reward, next_obs, next_candidate, next_graph_triplets = [], [], [], [], [], [], [], []
        for i, n in zip(idxes, ns):
            t = self._get_single_transition(i, n)
            if t is None:
                continue
            actual_indices.append(i)
            actual_ns.append(n)
            obs.append(t[0])
            candidate.append(t[1])
            chosen_indices.append(t[2])
            graph_triplets.append(t[3])
            reward.append(t[4])
            next_obs.append(t[5])
            next_candidate.append(t[6])
            next_graph_triplets.append(t[7])
        if len(actual_indices) == 0:
            return None
        chosen_indices = np.array(chosen_indices)  # batch
        reward = torch.stack(reward, 0)  # batch
        actual_ns = np.array(actual_ns)

        return [
            obs, candidate, chosen_indices, graph_triplets, reward, next_obs,
            next_candidate, next_graph_triplets, actual_indices, actual_ns
        ]

    def sample(self, batch_size, beta=0, multi_step=1):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        # sample n
        ns = np.random.randint(1, multi_step + 1, size=batch_size)
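        # Each transition gets its own random n-step horizon in [1, multi_step];
        # windows that cross an episode boundary are dropped in _encode_sample,
        # so the returned batch can be smaller than batch_size.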
        encoded_sample = self._encode_sample(idxes, ns)
        if encoded_sample is None:
            return None
        actual_indices = encoded_sample[-2]
        for idx in actual_indices:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        return encoded_sample + [weights]

    def _get_single_sequence_transition(self, idx, sample_history_length):
        assert sample_history_length > 0
        head = idx
        # if n is 1, then head can't be is_final
        if sample_history_length == 1:
            if self._storage[head].is_final:
                return None
        #  if n > 1, then all except tail can't be is_final
        else:
            if np.any([
                    item.is_final
                    for item in self._storage[head:head +
                                              sample_history_length]
            ]):
                return None

        next_final = self.get_next_final_pos(self._storage, head)
        if next_final is None:
            return None

        # all good
        res = []
        for m in range(sample_history_length):
            obs = self._storage[head + m].observation_list
            candidate = self._storage[head + m].action_candidate_list
            chosen_indices = self._storage[head + m].chosen_indices
            graph_triplets = self._storage[head + m].graph_triplets

            next_obs = self._storage[head + m + 1].observation_list
            next_candidate = self._storage[head + m + 1].action_candidate_list
            next_graph_triplets = self._storage[head + m + 1].graph_triplets

            tmp = next_final - (
                head + m) + 1 if self.accumulate_reward_from_final else 1

            rewards_up_to_next_final = [
                self.discount_gamma_game_reward**i *
                self._storage[head + m + i].reward for i in range(tmp)
            ]
            reward = torch.sum(torch.stack(rewards_up_to_next_final))

            graph_rewards_up_to_next_final = [
                self.discount_gamma_graph_reward**i *
                self._storage[head + m + i].graph_reward for i in range(tmp)
            ]
            graph_reward = torch.sum(
                torch.stack(graph_rewards_up_to_next_final))

            count_rewards_up_to_next_final = [
                self.discount_gamma_count_reward**i *
                self._storage[head + m + i].count_reward for i in range(tmp)
            ]
            count_reward = torch.sum(
                torch.stack(count_rewards_up_to_next_final))

            res.append([
                obs, candidate, chosen_indices, graph_triplets,
                reward + graph_reward + count_reward, next_obs, next_candidate,
                next_graph_triplets
            ])
        return res

    def _encode_sample_sequence(self, idxes, sample_history_length):
        assert sample_history_length > 0
        res = []
        for _ in range(sample_history_length):
            tmp = []
            for i in range(8):
                tmp.append([])
            res.append(tmp)

        actual_indices = []
        # obs, candidate, chosen_indices, graph_triplets, reward, next_obs, next_candidate, next_graph_triplets
        for i in idxes:
            t = self._get_single_sequence_transition(i, sample_history_length)
            if t is None:
                continue
            actual_indices.append(i)
            for step in range(sample_history_length):
                t_s = t[step]
                res[step][0].append(t_s[0])
                res[step][1].append(t_s[1])
                res[step][2].append(t_s[2])
                res[step][3].append(t_s[3])
                res[step][4].append(t_s[4])
                res[step][5].append(t_s[5])
                res[step][6].append(t_s[6])
                res[step][7].append(t_s[7])

        if len(actual_indices) == 0:
            return None
        for i in range(sample_history_length):
            res[i][2] = np.array(res[i][2])  # batch
            res[i][4] = torch.stack(res[i][4], 0)  # batch

        return res + [actual_indices]

    def sample_sequence(self, batch_size, beta=0, sample_history_length=1):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)
        res_weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        encoded_sample = self._encode_sample_sequence(idxes,
                                                      sample_history_length)
        if encoded_sample is None:
            return None
        actual_indices = encoded_sample[-1]
        for _h in range(sample_history_length):
            tmp_weights = []
            for idx in actual_indices:
                p_sample = self._it_sum[idx + _h] / self._it_sum.sum()
                weight = (p_sample * len(self._storage))**(-beta)
                tmp_weights.append(weight / max_weight)
            tmp_weights = np.array(tmp_weights)
            res_weights.append(tmp_weights)

        return encoded_sample + [res_weights]

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.
        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].
        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    def avg_rewards(self):
        if len(self._storage) == 0:
            return 0.0
        rewards = [self._storage[i].reward for i in range(len(self._storage))]
        return to_np(torch.mean(torch.stack(rewards)))
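
A self-contained sketch of the discounted reward accumulation performed in
_get_single_transition() above (game reward only; gamma, n and the rewards are made up):

import torch

gamma, n = 0.9, 3
rewards = [1.0, 0.5, 0.0, 2.0]  # rewards at head, head + 1, ..., head + n
discounted = [gamma ** i * rewards[i] for i in range(n + 1)]
print(torch.sum(torch.tensor(discounted)))  # tensor(2.9080)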
Example #6
class PrioritizedReplayMemory:
    def __init__(self, size, alpha=0.6, beta_start=0.4, beta_frames=100000):
        super(PrioritizedReplayMemory, self).__init__()
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

        assert alpha >= 0
        self._alpha = alpha

        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])

    def beta_by_frame(self, frame_idx):
        return min(
            1.0, self.beta_start + frame_idx *
            (1.0 - self.beta_start) / self.beta_frames)

    def push(self, state, action, reward, next_state, done):
        idx = self._next_idx

        if self._next_idx >= len(self._storage):
            self._storage.append(
                self.experience(state, action, reward, next_state, done))
        else:
            self._storage[self._next_idx] = self.experience(
                state, action, reward, next_state, done)
        self._next_idx = (self._next_idx + 1) % self._maxsize

        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _encode_sample(self, idxes):
        states = torch.from_numpy(
            np.vstack([self._storage[i].state
                       for i in idxes])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([self._storage[i].action
                       for i in idxes])).float().to(device)
        rewards = torch.from_numpy(
            np.vstack([self._storage[i].reward
                       for i in idxes])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([self._storage[i].next_state
                       for i in idxes])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([self._storage[i].done
                       for i in idxes]).astype(np.uint8)).float().to(device)
        return (states, actions, rewards, next_states, dones)

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size):
        idxes = self._sample_proportional(batch_size)

        weights = []

        #find smallest sampling prob: p_min = smallest priority^alpha / sum of priorities^alpha
        p_min = self._it_min.min() / self._it_sum.sum()

        beta = self.beta_by_frame(self.frame)
        self.frame += 1

        #max_weight given to smallest prob
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = torch.tensor(weights, device=device, dtype=torch.float)
        encoded_sample = self._encode_sample(idxes)
        return encoded_sample, idxes, weights

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = (priority + 1e-5)**self._alpha
            self._it_min[idx] = (priority + 1e-5)**self._alpha

            self._max_priority = max(self._max_priority, (priority + 1e-5))
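
A self-contained sketch of the schedule implemented by beta_by_frame() above: beta is
annealed linearly from beta_start towards 1.0 over beta_frames frames and then capped
(the defaults from __init__ are used here):

beta_start, beta_frames = 0.4, 100000
for frame_idx in (1, 50000, 100000, 200000):
    beta = min(1.0, beta_start + frame_idx * (1.0 - beta_start) / beta_frames)
    print(frame_idx, round(beta, 3))  # -> 0.4, 0.7, 1.0, 1.0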
Example #7
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
Example #8
class PrioritizedReplayBuffer(ReplayBuffer):
    """
    Adapt from https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/buffers.py
    """
    def __init__(self, obs_space, action_space, capacity, exponent, device, optimize_memory_usage=False):
        super().__init__(obs_space, action_space, capacity, device,
                         optimize_memory_usage=optimize_memory_usage)
        assert exponent >= 0
        self.exponent = exponent

        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def _sample_proportional(self, batch_size):
        total = self._it_sum.sum(0, len(self) - 1)
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)

        # replace idx == self.idx
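        # With optimize_memory_usage the next observation is read from the slot after
        # each index, so the transition at the current write pointer (self.idx) is not
        # valid to sample and is presumably redrawn here for that reason.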
        if self.full and self.optimize_memory_usage:
            while np.any(idx == self.idx):
                # Draw one replacement mass per colliding index; the boolean mask
                # selects exactly n_replace entries, so the replacement array must
                # have that length.
                n_replace = int(np.sum(idx == self.idx))
                replace_mass = np.random.random(n_replace) * total
                replace_idx = self._it_sum.find_prefixsum_idx(replace_mass)
                idx[idx == self.idx] = replace_idx

        return idx

    def add(self, obs, action, reward, next_obs, done):
        idx = self.idx
        super().add(obs, action, reward, next_obs, done)
        self._it_sum[idx] = self._max_priority ** self.exponent
        self._it_min[idx] = self._max_priority ** self.exponent

    def sample(self, batch_size, beta=0):
        assert beta >= 0

        idxes = self._sample_proportional(batch_size)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self)) ** (-beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self)) ** (-beta) / max_weight
        obses, actions, rewards, next_obses, not_dones = self._sample(idxes)

        priority_kwargs = {
            'weights': weights,
            'idxes': idxes
        }

        return obses, actions, rewards, next_obses, not_dones, priority_kwargs

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self)
        self._it_sum[idxes] = priorities ** self.exponent
        self._it_min[idxes] = priorities ** self.exponent

        self._max_priority = max(self._max_priority, np.max(priorities))
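
A self-contained sketch of the vectorized weighting done in sample() above, with a plain
priority array in place of the segment trees and np.random.choice standing in for the
prefix-sum search (the numbers are made up):

import numpy as np

exponent, beta = 0.6, 0.4
priorities = np.array([1.0, 1.0, 2.0, 4.0]) ** exponent
probs = priorities / priorities.sum()
idxes = np.random.default_rng(0).choice(len(probs), size=8, p=probs)
max_weight = (probs.min() * len(probs)) ** (-beta)
weights = (probs[idxes] * len(probs)) ** (-beta) / max_weight
print(idxes, weights)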