Example #1
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.
    Parameters
    ----------
    size: int
        Max number of transitions to store in the buffer. When the buffer
        overflows the old memories are dropped.
    alpha: float
        how much prioritization is used
        (0 - no prioritization, 1 - full prioritization)
    See Also
    --------
    ReplayBuffer.__init__
    """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
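The alpha exponent stored above controls how strongly priorities skew sampling: transition i is drawn with probability proportional to priority_i ** alpha, so alpha = 0 recovers uniform sampling and alpha = 1 samples proportionally to the raw priorities. A minimal, illustrative sketch of that relationship (not part of the example above):

import numpy as np

def sampling_probabilities(priorities, alpha):
    # Probability of each transition under proportional prioritization.
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()

print(sampling_probabilities([1.0, 2.0, 4.0], alpha=0.0))  # uniform: [0.333 0.333 0.333]
print(sampling_probabilities([1.0, 2.0, 4.0], alpha=1.0))  # proportional: [0.143 0.286 0.571]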
Example #2
    def clean(self):
        buffer_size, state_dim, obs_space, action_shape = self.capacity, self.state_dim, self.obs_space, self.action_shape
        self.curr_capacity = 0
        self.pointer = 0

        self.query_buffer = np.zeros((buffer_size, state_dim))
        self._q_values = -np.inf * np.ones(buffer_size + 1)
        self.returns = -np.inf * np.ones(buffer_size + 1)
        self.replay_buffer = np.empty((buffer_size, ) + obs_space.shape,
                                      np.float32)
        self.action_buffer = np.empty((buffer_size, ) + action_shape,
                                      np.float32)
        self.reward_buffer = np.empty((buffer_size, ), np.float32)
        self.steps = np.empty((buffer_size, ), np.int64)
        self.done_buffer = np.empty((buffer_size, ), np.bool_)
        self.truly_done_buffer = np.empty((buffer_size, ), np.bool_)
        self.next_id = -1 * np.ones(buffer_size)
        self.prev_id = [[] for _ in range(buffer_size)]
        self.ddpg_q_values = -np.inf * np.ones(buffer_size)
        self.contra_count = np.ones((buffer_size, ))
        self.lru = np.zeros(buffer_size)
        self.time = 0

        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self.end_points = []
Example #3
def test_prefixsum_idx2():
    """
    test Segment Tree data structure
    """
    tree = SumSegmentTree(4)

    tree[np.array([0, 1, 2, 3])] = [0.5, 1.0, 1.0, 3.0]

    assert tree.find_prefixsum_idx(0.00) == 0
    assert tree.find_prefixsum_idx(0.55) == 1
    assert tree.find_prefixsum_idx(0.99) == 1
    assert tree.find_prefixsum_idx(1.51) == 2
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(5.50) == 3

    tree = SumSegmentTree(4)

    tree[0] = 0.5
    tree[1] = 1.0
    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.00) == 0
    assert tree.find_prefixsum_idx(0.55) == 1
    assert tree.find_prefixsum_idx(0.99) == 1
    assert tree.find_prefixsum_idx(1.51) == 2
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(5.50) == 3
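The assertions above follow from the intended semantics of find_prefixsum_idx: it returns the first index whose running (prefix) sum of priorities exceeds the queried mass, clipped to the last slot. A plain-array reference sketch of that behaviour (an illustration, not the segment-tree implementation itself):

import numpy as np

def find_prefixsum_idx_reference(values, prefixsum):
    # First index whose cumulative sum exceeds `prefixsum`, clipped to the last slot.
    cumsum = np.cumsum(values)
    idx = int(np.searchsorted(cumsum, prefixsum, side='right'))
    return min(idx, len(values) - 1)

values = [0.5, 1.0, 1.0, 3.0]          # cumulative sums: [0.5, 1.5, 2.5, 5.5]
assert find_prefixsum_idx_reference(values, 0.55) == 1
assert find_prefixsum_idx_reference(values, 1.51) == 2
assert find_prefixsum_idx_reference(values, 5.50) == 3   # clipped to the last index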
Example #4
    def __init__(self, size, alpha):
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
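The capacity loop in these constructors simply rounds size up to the next power of two, which the segment trees require. An equivalent one-liner, shown here only for illustration:

def next_power_of_two(size):
    # Same result as: it_capacity = 1; while it_capacity < size: it_capacity *= 2
    return 1 if size <= 1 else 1 << (size - 1).bit_length()

assert next_power_of_two(1) == 1
assert next_power_of_two(5) == 8
assert next_power_of_two(8) == 8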
Example #5
def test_prefixsum_idx():
    """
    test Segment Tree data structure
    """
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3
Example #6
    def __init__(self,
                 capacity,
                 alpha,
                 transition_small_epsilon=1e-6,
                 demo_epsilon=0.2):
        super(PrioritizedMemory, self).__init__(capacity)
        assert alpha > 0
        self._alpha = alpha
        self._transition_small_epsilon = transition_small_epsilon
        self._demo_epsilon = demo_epsilon
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2  # Size must be a power of 2
        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 100.0
        self.help = []
Example #7
def test_tree_set():
    """
    test Segment Tree data structure
    """
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert np.isclose(tree.sum(), 4.0)
    assert np.isclose(tree.sum(0, 2), 0.0)
    assert np.isclose(tree.sum(0, 3), 1.0)
    assert np.isclose(tree.sum(2, 3), 1.0)
    assert np.isclose(tree.sum(2, -1), 1.0)
    assert np.isclose(tree.sum(2, 4), 4.0)
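These assertions pin down the interval convention of SumSegmentTree.sum: a half-open [start, end) range with Python-slice semantics, so a negative end counts from the back and omitting both arguments sums everything. A plain-array sketch of the same behaviour (illustration only):

import numpy as np

def tree_sum_reference(values, start=0, end=None):
    # Plain-slice equivalent of SumSegmentTree.sum over [start, end).
    return float(np.asarray(values, dtype=np.float64)[start:end].sum())

values = [0.0, 0.0, 1.0, 3.0]
assert np.isclose(tree_sum_reference(values), 4.0)
assert np.isclose(tree_sum_reference(values, 0, 3), 1.0)
assert np.isclose(tree_sum_reference(values, 2, -1), 1.0)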
Example #8
def test_tree_set_overlap():
    """
    test Segment Tree data structure
    """
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[2] = 3.0

    assert np.isclose(tree.sum(), 3.0)
    assert np.isclose(tree.sum(2, 3), 3.0)
    assert np.isclose(tree.sum(2, -1), 3.0)
    assert np.isclose(tree.sum(2, 4), 3.0)
    assert np.isclose(tree.sum(1, 2), 0.0)
Example #9
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old memories
            are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=0):
        """
        Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction)
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight of
                each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
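In use, the class above sits in an add / sample / update_priorities loop. A schematic, self-contained usage sketch, assuming the ReplayBuffer base class from the same library is available; the random transitions and random TD errors below are placeholders for a real environment and Q-network:

import numpy as np

buffer = PrioritizedReplayBuffer(size=1000, alpha=0.6)
for _ in range(100):
    buffer.add(np.random.randn(4), 0, float(np.random.randn()), np.random.randn(4), False)

beta = 0.4  # typically annealed towards 1.0 over the course of training
obs_t, actions, rewards, obs_tp1, dones, weights, idxes = buffer.sample(batch_size=32, beta=beta)

td_errors = np.random.randn(32)                             # stand-in for real TD errors
loss = np.mean(weights * td_errors ** 2)                    # importance-weighted loss
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)   # keep priorities strictly positive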
Example #10
    def __init__(self,
                 buffer_size,
                 state_dim,
                 action_shape,
                 obs_space,
                 q_func,
                 repr_func,
                 obs_ph,
                 action_ph,
                 sess,
                 gamma=0.99,
                 alpha=0.6,
                 max_step=1000):
        buffer_size = int(buffer_size)
        self.state_dim = state_dim
        self.capacity = buffer_size
        self.curr_capacity = 0
        self.pointer = 0
        self.obs_space = obs_space
        self.action_shape = action_shape
        self.max_step = max_step

        self.query_buffer = np.zeros((buffer_size, state_dim))
        self._q_values = -np.inf * np.ones(buffer_size + 1)
        self.returns = -np.inf * np.ones(buffer_size + 1)
        self.replay_buffer = np.empty((buffer_size, ) + obs_space.shape,
                                      np.float32)
        self.action_buffer = np.empty((buffer_size, ) + action_shape,
                                      np.float32)
        self.reward_buffer = np.empty((buffer_size, ), np.float32)
        self.steps = np.empty((buffer_size, ), np.int64)
        self.done_buffer = np.empty((buffer_size, ), np.bool_)
        self.truly_done_buffer = np.empty((buffer_size, ), np.bool_)
        self.next_id = -1 * np.ones(buffer_size)
        self.prev_id = [[] for _ in range(buffer_size)]
        self.ddpg_q_values = -np.inf * np.ones(buffer_size)
        self.contra_count = np.ones((buffer_size, ))
        self.lru = np.zeros(buffer_size)
        self.time = 0
        self.gamma = gamma
        # self.hashes = dict()
        self.reward_mean = None
        self.min_return = 0
        self.end_points = []
        assert alpha > 0
        self._alpha = alpha
        self.beta_set = [-1]
        self.beta_coef = [1.]
        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

        self.q_func = q_func
        self.repr_func = repr_func
        self.obs_ph = obs_ph
        self.action_ph = action_ph
        self.sess = sess
Example #11
class EpisodicMemory(object):
    def __init__(self,
                 buffer_size,
                 state_dim,
                 action_shape,
                 obs_space,
                 q_func,
                 repr_func,
                 obs_ph,
                 action_ph,
                 sess,
                 gamma=0.99,
                 alpha=0.6,
                 max_step=1000):
        buffer_size = int(buffer_size)
        self.state_dim = state_dim
        self.capacity = buffer_size
        self.curr_capacity = 0
        self.pointer = 0
        self.obs_space = obs_space
        self.action_shape = action_shape
        self.max_step = max_step

        self.query_buffer = np.zeros((buffer_size, state_dim))
        self._q_values = -np.inf * np.ones(buffer_size + 1)
        self.returns = -np.inf * np.ones(buffer_size + 1)
        self.replay_buffer = np.empty((buffer_size, ) + obs_space.shape,
                                      np.float32)
        self.action_buffer = np.empty((buffer_size, ) + action_shape,
                                      np.float32)
        self.reward_buffer = np.empty((buffer_size, ), np.float32)
        self.steps = np.empty((buffer_size, ), np.int64)
        self.done_buffer = np.empty((buffer_size, ), np.bool_)
        self.truly_done_buffer = np.empty((buffer_size, ), np.bool_)
        self.next_id = -1 * np.ones(buffer_size)
        self.prev_id = [[] for _ in range(buffer_size)]
        self.ddpg_q_values = -np.inf * np.ones(buffer_size)
        self.contra_count = np.ones((buffer_size, ))
        self.lru = np.zeros(buffer_size)
        self.time = 0
        self.gamma = gamma
        # self.hashes = dict()
        self.reward_mean = None
        self.min_return = 0
        self.end_points = []
        assert alpha > 0
        self._alpha = alpha
        self.beta_set = [-1]
        self.beta_coef = [1.]
        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

        self.q_func = q_func
        self.repr_func = repr_func
        self.obs_ph = obs_ph
        self.action_ph = action_ph
        self.sess = sess

    def clean(self):
        buffer_size, state_dim, obs_space, action_shape = self.capacity, self.state_dim, self.obs_space, self.action_shape
        self.curr_capacity = 0
        self.pointer = 0

        self.query_buffer = np.zeros((buffer_size, state_dim))
        self._q_values = -np.inf * np.ones(buffer_size + 1)
        self.returns = -np.inf * np.ones(buffer_size + 1)
        self.replay_buffer = np.empty((buffer_size, ) + obs_space.shape,
                                      np.float32)
        self.action_buffer = np.empty((buffer_size, ) + action_shape,
                                      np.float32)
        self.reward_buffer = np.empty((buffer_size, ), np.float32)
        self.steps = np.empty((buffer_size, ), np.int64)
        self.done_buffer = np.empty((buffer_size, ), np.bool_)
        self.truly_done_buffer = np.empty((buffer_size, ), np.bool_)
        self.next_id = -1 * np.ones(buffer_size)
        self.prev_id = [[] for _ in range(buffer_size)]
        self.ddpg_q_values = -np.inf * np.ones(buffer_size)
        self.contra_count = np.ones((buffer_size, ))
        self.lru = np.zeros(buffer_size)
        self.time = 0

        it_capacity = 1
        while it_capacity < buffer_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self.end_points = []

    @property
    def q_values(self):
        return self._q_values

    def squeeze(self, obses):
        return np.array([(obs - self.obs_space.low) /
                         (self.obs_space.high - self.obs_space.low)
                         for obs in obses])

    def unsqueeze(self, obses):
        return np.array([
            obs * (self.obs_space.high - self.obs_space.low) +
            self.obs_space.low for obs in obses
        ])

    def save(self, filedir):
        save_dict = {
            "query_buffer": self.query_buffer,
            "returns": self.returns,
            "replay_buffer": self.replay_buffer,
            "reward_buffer": self.reward_buffer,
            "truly_done_buffer": self.truly_done_buffer,
            "next_id": self.next_id,
            "prev_id": self.prev_id,
            "gamma": self.gamma,
            "_q_values": self._q_values,
            "done_buffer": self.done_buffer,
            "curr_capacity": self.curr_capacity,
            "capacity": self.capacity
        }
        with open(os.path.join(filedir, "episodic_memory.pkl"),
                  "wb") as memory_file:
            pkl.dump(save_dict, memory_file)

    def add(self, obs, action, state, sampled_return, next_id=-1):

        index = self.pointer
        self.pointer = (self.pointer + 1) % self.capacity

        if self.curr_capacity >= self.capacity:
            # Clean up old entry
            if index in self.end_points:
                self.end_points.remove(index)
            self.prev_id[index] = []
            self.next_id[index] = -1
            self.q_values[index] = -np.inf
        else:
            self.curr_capacity = min(self.capacity, self.curr_capacity + 1)
        # Store new entry
        self.replay_buffer[index] = obs
        self.action_buffer[index] = action
        if state is not None:
            self.query_buffer[index] = state
        self.q_values[index] = sampled_return
        self.returns[index] = sampled_return
        self.lru[index] = self.time

        self._it_sum[index] = self._max_priority**self._alpha
        self._it_min[index] = self._max_priority**self._alpha
        if next_id >= 0:
            self.next_id[index] = next_id
            if index not in self.prev_id[next_id]:
                self.prev_id[next_id].append(index)
        self.time += 0.01

        return index

    def update_priority(self, idxes, priorities):
        # priorities = 1 / np.sqrt(self.contra_count[:self.curr_capacity])
        # priorities = priorities / np.max(priorities)
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            priority = max(priority, 1e-6)
            # assert priority > 0
            assert 0 <= idx < self.capacity
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    def sample_neg_keys(self, avoids, batch_size):
        # sample negative keys
        assert batch_size + len(
            avoids
        ) <= self.capacity, "can't sample that much neg samples from episodic memory!"
        places = []
        while len(places) < batch_size:
            ind = np.random.randint(0, self.curr_capacity)
            if ind not in places:
                places.append(ind)
        return places

    def compute_approximate_return(self, obses, actions=None):
        return np.min(np.array(
            self.sess.run(self.q_func, feed_dict={self.obs_ph: obses})),
                      axis=0)

    def compute_statistics(self, batch_size=1024):
        estimated_qs = []
        for i in range(math.ceil(self.curr_capacity / batch_size)):
            start = i * batch_size
            end = min((i + 1) * batch_size, self.curr_capacity)
            obses = self.replay_buffer[start:end]
            actions = None
            estimated_qs.append(
                self.compute_approximate_return(obses, actions).reshape(-1))
        estimated_qs = np.concatenate(estimated_qs)
        diff = estimated_qs - self.q_values[:self.curr_capacity]
        return np.min(diff), np.mean(diff), np.max(diff)

    def retrieve_trajectories(self):
        trajs = []
        for e in self.end_points:
            traj = []
            prev = e
            while prev is not None:
                traj.append(prev)
                try:
                    prev = self.prev_id[prev][0]
                    # print(e,prev)
                except IndexError:
                    prev = None
            # print(np.array(traj))
            trajs.append(np.array(traj))
        return trajs

    def update_memory(self, q_base=0, use_knn=False, beta=-1):

        trajs = self.retrieve_trajectories()
        for traj in trajs:
            # print(np.array(traj))
            approximate_qs = self.compute_approximate_return(
                self.replay_buffer[traj], self.action_buffer[traj])
            approximate_qs = approximate_qs.reshape(-1)
            approximate_qs = np.insert(approximate_qs, 0, 0)

            self.q_values[traj] = 0
            Rtn = -1e10 if beta < 0 else 0
            for i, s in enumerate(traj):
                approximate_q = self.reward_buffer[s] + \
                                self.gamma * (1 - self.truly_done_buffer[s]) * (approximate_qs[i] - q_base)
                Rtn = self.reward_buffer[s] + self.gamma * (
                    1 - self.truly_done_buffer[s]) * Rtn
                if beta < 0:
                    Rtn = max(Rtn, approximate_q)
                else:
                    Rtn = beta * Rtn + (1 - beta) * approximate_q
                self.q_values[s] = Rtn

    def update_sequence_with_qs(self, sequence):
        # print(sequence)
        next_id = -1
        Rtd = 0
        for obs, a, z, q_t, r, truly_done, done in reversed(sequence):
            # print(np.mean(z))
            if truly_done:
                Rtd = r
            else:
                Rtd = self.gamma * Rtd + r
            current_id = self.add(obs, a, z, Rtd, next_id)

            if done:
                self.end_points.append(current_id)
            self.replay_buffer[current_id] = obs
            self.reward_buffer[current_id] = r
            self.truly_done_buffer[current_id] = truly_done
            self.done_buffer[current_id] = done
            next_id = int(current_id)
        # self.update_priority()
        return

    def sample_negative(self, batch_size, batch_idxs, batch_idxs_next,
                        batch_idx_pre):
        neg_batch_idxs = []
        i = 0
        while i < batch_size:
            neg_idx = np.random.randint(0, self.curr_capacity - 2)
            if neg_idx != batch_idxs[i] and neg_idx != batch_idxs_next[
                    i] and neg_idx not in batch_idx_pre[i]:
                neg_batch_idxs.append(neg_idx)
                i += 1
        neg_batch_idxs = np.array(neg_batch_idxs)
        return neg_batch_idxs, self.replay_buffer[neg_batch_idxs]

    @staticmethod
    def switch_first_half(obs0, obs1, batch_size):
        tmp = copy.copy(obs0[:batch_size // 2, ...])
        obs0[:batch_size // 2, ...] = obs1[:batch_size // 2, ...]
        obs1[:batch_size // 2, ...] = tmp
        return obs0, obs1

    def sample(self, batch_size, mix=False, priority=False):
        # Draw such that we always have a proceeding element
        if self.curr_capacity < batch_size + len(self.end_points):
            return None
        # if priority:
        #     self.update_priority()
        batch_idxs = []
        batch_idxs_next = []
        count = 0
        while len(batch_idxs) < batch_size:
            if priority:
                mass = random.random() * self._it_sum.sum(
                    0, self.curr_capacity)
                rnd_idx = self._it_sum.find_prefixsum_idx(mass)
            else:
                rnd_idx = np.random.randint(0, self.curr_capacity)
            count += 1
            assert count < 1e8
            if self.next_id[rnd_idx] == -1:
                continue
                # be careful !!!!!! I use random id because in our implementation obs1 is never used
                # if len(self.prev_id[rnd_idx]) > 0:
                #     batch_idxs_next.append(self.prev_id[rnd_idx][0])
                # else:
                #     batch_idxs_next.append(0)
            else:
                batch_idxs_next.append(self.next_id[rnd_idx])
                batch_idxs.append(rnd_idx)

        batch_idxs = np.array(batch_idxs).astype(np.int64)
        batch_idxs_next = np.array(batch_idxs_next).astype(np.int64)
        # batch_idx_pre = [self.prev_id[id] for id in batch_idxs]

        obs0_batch = self.replay_buffer[batch_idxs]
        obs1_batch = self.replay_buffer[batch_idxs_next]
        # batch_idxs_neg, obs2_batch = self.sample_negative(batch_size, batch_idxs, batch_idxs_next, batch_idx_pre)
        action_batch = self.action_buffer[batch_idxs]
        action1_batch = self.action_buffer[batch_idxs_next]
        # action2_batch = self.action_buffer[batch_idxs_neg]
        reward_batch = self.reward_buffer[batch_idxs]
        terminal1_batch = self.done_buffer[batch_idxs]
        q_batch = self.q_values[batch_idxs]
        return_batch = self.returns[batch_idxs]

        if mix:
            obs0_batch, obs1_batch = self.switch_first_half(
                obs0_batch, obs1_batch, batch_size)
        if priority:
            self.contra_count[batch_idxs] += 1
            self.contra_count[batch_idxs_next] += 1

        result = {
            'index0':
            array_min2d(batch_idxs),
            'index1':
            array_min2d(batch_idxs_next),
            # 'index2': array_min2d(batch_idxs_neg),
            'obs0':
            array_min2d(obs0_batch),
            'obs1':
            array_min2d(obs1_batch),
            # 'obs2': array_min2d(obs2_batch),
            'rewards':
            array_min2d(reward_batch),
            'actions':
            array_min2d(action_batch),
            'actions1':
            array_min2d(action1_batch),
            # 'actions2': array_min2d(action2_batch),
            'count':
            array_min2d(self.contra_count[batch_idxs] +
                        self.contra_count[batch_idxs_next]),
            'terminals1':
            array_min2d(terminal1_batch),
            'return':
            array_min2d(q_batch),
            'true_return':
            array_min2d(return_batch),
        }
        return result

    def plot(self):
        X = self.replay_buffer[:self.curr_capacity]
        model = TSNE()
        low_dim_data = model.fit_transform(X)
        plt.scatter(low_dim_data[:, 0], low_dim_data[:, 1])
        plt.show()
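The update_sequence_with_qs method above walks an episode backwards, resetting the return at true terminals and otherwise accumulating Rtd = gamma * Rtd + r. A standalone sketch of that backward pass, with made-up rewards for illustration:

def discounted_returns(rewards, truly_dones, gamma=0.99):
    # Backward accumulation in the same spirit as update_sequence_with_qs.
    rtd = 0.0
    returns = []
    for r, truly_done in zip(reversed(rewards), reversed(truly_dones)):
        rtd = r if truly_done else gamma * rtd + r
        returns.append(rtd)
    return list(reversed(returns))

# A 3-step episode ending in a true terminal state:
print(discounted_returns([1.0, 0.0, 2.0], [False, False, True], gamma=0.9))
# -> approximately [2.62, 1.8, 2.0]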
Example #12
class TwoWayPrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        super(TwoWayPrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def extend(self, obs_t, action, reward, obs_tp1, done):
        idx = self._next_idx
        super().extend(obs_t, action, reward, obs_tp1, done)
        while idx != self._next_idx:
            self._it_sum[idx] = self._max_priority**self._alpha
            self._it_min[idx] = self._max_priority**self._alpha
            idx = (idx + 1) % self._maxsize

    def _sample_proportional(self, batch_size):
        mass = []
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self,
               batch_size: int,
               beta: float = 0,
               env: Optional[VecNormalize] = None):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self._storage))**(-beta) / max_weight
        encoded_sample = self._encode_sample(idxes, env=env)
        return tuple(list(encoded_sample) + [weights, idxes])

    def _sample_invproportional(self, batch_size):
        mass = []
        total = self._it_sum.sum(len(self._storage) - 1, 0)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def invsample(self,
                  batch_size: int,
                  beta: float = 0,
                  env: Optional[VecNormalize] = None):
        assert beta > 0

        idxes = self._sample_invproportional(batch_size)
        encoded_sample = self._encode_sample(idxes, env=env)
        return list(encoded_sample)

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self.storage)
        self._it_sum[idxes] = priorities**self._alpha
        self._it_min[idxes] = priorities**self._alpha

        self._max_priority = max(self._max_priority, np.max(priorities))
Example #13
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old memories
            are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def extend(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new batch of transitions to the buffer

        :param obs_t: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the last batch of observations
        :param action: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the batch of actions
        :param reward: (Union[Tuple[float], np.ndarray]) the batch of the rewards of the transition
        :param obs_tp1: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the current batch of observations
        :param done: (Union[Tuple[bool], np.ndarray]) terminal status of the batch

        Note: uses the same names as .add to keep compatibility with named argument passing
            but expects iterables and arrays with more than 1 dimensions
        """
        idx = self._next_idx
        super().extend(obs_t, action, reward, obs_tp1, done)
        while idx != self._next_idx:
            self._it_sum[idx] = self._max_priority**self._alpha
            self._it_min[idx] = self._max_priority**self._alpha
            idx = (idx + 1) % self._maxsize

    def _sample_proportional(self, batch_size):
        mass = []
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self,
               batch_size: int,
               beta: float = 0,
               env: Optional[VecNormalize] = None):
        """
        Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction)
        :param env: (Optional[VecNormalize]) associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight of
                each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self._storage))**(-beta) / max_weight
        encoded_sample = self._encode_sample(idxes, env=env)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self.storage)
        self._it_sum[idxes] = priorities**self._alpha
        self._it_min[idxes] = priorities**self._alpha

        self._max_priority = max(self._max_priority, np.max(priorities))
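The beta argument of sample is typically annealed from a small starting value towards 1.0 so the importance-sampling correction becomes exact late in training. A small, self-contained schedule one might pair with this buffer (illustration only; these parameter names are not part of the example above):

def linear_beta_schedule(step, total_steps, beta_start=0.4, beta_end=1.0):
    # Linearly anneal beta from beta_start to beta_end over total_steps.
    fraction = min(float(step) / total_steps, 1.0)
    return beta_start + fraction * (beta_end - beta_start)

assert linear_beta_schedule(0, 10000) == 0.4
assert abs(linear_beta_schedule(10000, 10000) - 1.0) < 1e-9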
Example #14
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.
    Parameters
    ----------
    size: int
        Max number of transitions to store in the buffer. When the buffer
        overflows the old memories are dropped.
    alpha: float
        how much prioritization is used
        (0 - no prioritization, 1 - full prioritization)
    See Also
    --------
    ReplayBuffer.__init__
    """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.
    compared to ReplayBuffer.sample
    it also returns importance weights and idxes
    of sampled experiences.
    Parameters
    ----------
    batch_size: int
        How many transitions to sample.
    beta: float
        To what degree to use importance weights
        (0 - no corrections, 1 - full correction)
    Returns
    -------
    obs_batch: np.array
        batch of observations
    act_batch: np.array
        batch of actions executed given obs_batch
    R_batch: np.array
        returns received as results of executing act_batch
    weights: np.array
        Array of shape (batch_size,) and dtype np.float32
        denoting importance weight of each sampled transition
    idxes: np.array
        Array of shape (batch_size,) and dtype np.int32
        indexes in buffer of sampled experiences
    """

        idxes = self._sample_proportional(batch_size)

        if beta > 0:
            weights = []
            p_min = self._it_min.min() / self._it_sum.sum()
            max_weight = (p_min * len(self._storage))**(-beta)

            for idx in idxes:
                p_sample = self._it_sum[idx] / self._it_sum.sum()
                weight = (p_sample * len(self._storage))**(-beta)
                weights.append(weight / max_weight)
            weights = np.array(weights)
        else:
            weights = np.ones_like(idxes, dtype=np.float32)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.
    sets priority of transition at index idxes[i] in buffer
    to priorities[i].
    Parameters
    ----------
    idxes: [int]
        List of idxes of sampled transitions
    priorities: [float]
        List of updated priorities corresponding to
        transitions at the sampled idxes denoted by
        variable `idxes`.
    """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            priority = max(priority, 1e-6)
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
Example #15
def test_prefixsum_idx():
    """
    test Segment Tree data structure
    """
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3
    assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3])

    tree = SumSegmentTree(4)

    tree[np.array([2, 3])] = [1.0, 3.0]

    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3
    assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3])
Example #16
class PrioritizedMemory(Memory):
    def __init__(self,
                 capacity,
                 alpha,
                 transition_small_epsilon=1e-6,
                 demo_epsilon=0.2):
        super(PrioritizedMemory, self).__init__(capacity)
        assert alpha > 0
        self._alpha = alpha
        self._transition_small_epsilon = transition_small_epsilon
        self._demo_epsilon = demo_epsilon
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2  # Size must be power of 2
        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 100.0
        self.help = []

    def append(self, obs0, obs1, f_s0, f_s1, action, reward, terminal1):
        idx = self._next_idx
        if not super(PrioritizedMemory, self).append_(obs0=obs0,
                                                      obs1=obs1,
                                                      f_s0=f_s0,
                                                      f_s1=f_s1,
                                                      actions=action,
                                                      rewards=reward,
                                                      terminal1=terminal1):
            return
        # Newly added transitions get the maximum priority so that every transition is sampled at least once
        self._it_sum[idx] = self._max_priority
        self._it_min[idx] = self._max_priority

    def append_demo(self, obs0, obs1, f_s0, f_s1, action, reward, terminal1):
        idx = self._next_idx
        if not super(PrioritizedMemory, self).append_(obs0=obs0,
                                                      obs1=obs1,
                                                      f_s0=f_s0,
                                                      f_s1=f_s1,
                                                      actions=action,
                                                      rewards=reward,
                                                      terminal1=terminal1,
                                                      count=False):
            return
        self._it_sum[idx] = self._max_priority
        self._it_min[idx] = self._max_priority
        self.num_demonstrations += 1

    def _sample_proportional(self, batch_size, pretrain):
        res = []
        if pretrain:
            res = np.random.randint(low=0,
                                    high=self.nb_entries,
                                    size=batch_size)
            return res
        for _ in range(batch_size):
            while True:
                mass = np.random.uniform(
                    0, self._it_sum.sum(0,
                                        len(self.storage) - 1))
                idx = self._it_sum.find_prefixsum_idx(mass)
                if idx not in res:
                    res.append(idx)
                    break
        return res

    def sample(self, batch_size, beta, pretrain=False):
        idxes = self._sample_proportional(batch_size, pretrain)
        # demos is a bool
        demos = [i < self.num_demonstrations for i in idxes]
        weights = []
        p_sum = self._it_sum.sum()
        # Compute the importance-sampling weights
        for idx in idxes:
            p_sample = self._it_sum[idx] / p_sum
            weight = ((1.0 / p_sample) * (1.0 / len(self.storage)))**beta
            weights.append(weight)
        weights = np.array(weights) / np.max(weights)
        encoded_sample = self._get_batches_for_idxes(idxes)
        encoded_sample['weights'] = array_min2d(weights)
        encoded_sample['idxes'] = array_min2d(idxes)
        encoded_sample['demos'] = array_min2d(demos)
        return encoded_sample

    def sample_rollout(self, batch_size, nsteps, beta, gamma, pretrain=False):
        batches = self.sample(batch_size, beta, pretrain)
        n_step_batches = {
            storable_element: []
            for storable_element in self.storable_elements
        }
        n_step_batches["step_reached"] = []
        idxes = batches["idxes"]
        for idx in idxes:
            local_idxes = list(
                range(int(idx), int(min(idx + nsteps, len(self)))))
            transitions = self._get_batches_for_idxes(local_idxes)
            summed_reward = 0
            count = 0
            terminal = 0.0
            terminals = transitions['terminals1']
            r = transitions['rewards']
            for i in range(len(r)):
                summed_reward += (gamma**i) * r[i][0]
                count = i
                if terminals[i]:
                    terminal = 1.0
                    break

            n_step_batches["step_reached"].append(count)
            n_step_batches["obs1"].append(transitions["obs1"][count])
            n_step_batches["f_s1"].append(transitions["f_s1"][count])
            n_step_batches["terminals1"].append(terminal)
            n_step_batches["rewards"].append(summed_reward)
            n_step_batches["actions"].append(transitions["actions"][0])
            n_step_batches['demos'] = batches['demos']
        n_step_batches['weights'] = batches['weights']
        n_step_batches['idxes'] = idxes

        n_step_batches = {k: array_min2d(v) for k, v in n_step_batches.items()}

        return batches, n_step_batches, sum(batches['demos']) / batch_size

    def update_priorities(self, idxes, td_errors, actor_losses=0.0):
        priorities = td_errors + \
            (actor_losses ** 2) + self._transition_small_epsilon
        for i in range(len(priorities)):
            if idxes[i] < self.num_demonstrations:
                priorities[i] += np.max(priorities) * self._demo_epsilon
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self.storage)
            idx = int(idx)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha
            self._max_priority = max(self._max_priority, priority**self._alpha)
            self.help.append(priority**self._alpha)