import copy
import random
from collections import deque, namedtuple

import numpy as np
import torch

# the original project is assumed to define the torch device along these lines;
# SumTree comes from elsewhere in the project (a minimal sketch follows this example)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""
    def __init__(self,
                 buffer_size,
                 batch_size,
                 td_eps,
                 seed,
                 p_replay_alpha,
                 reward_scale=False,
                 error_clip=False,
                 error_max=1.0,
                 error_init=False,
                 use_tree=False,
                 err_init=1.0):
        """Initialize a ReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            td_eps: (float): to avoid zero td_error
            p_replay_alpha (float): discount factor for priority sampling
            reward_scale (flag): to scale reward down by 10
            error_clip (flag): max error to 1
            seed (int): random seed
        """
        self.useTree = use_tree
        self.memory = deque(maxlen=buffer_size)
        self.tree = SumTree(buffer_size)  #create tree instance
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.td_eps = td_eps
        self.experience = namedtuple("Experience",
                                     field_names=[
                                         "state", "action", "reward",
                                         "next_state", "done", "td_error"
                                     ])
        self.seed = random.seed(seed)
        self.p_replay_alpha = p_replay_alpha
        self.reward_scale = reward_scale
        self.error_clip = error_clip
        self.error_init = error_init
        self.error_max = error_max

        self.memory_index = np.zeros([self.buffer_size, 1])  # parallel array of priorities for fast sampling
        self.memory_pointer = 0

    def add(self, state, action, reward, next_state, done, td_error):
        """Add a new experience to memory.
        td_error: absolute value of the TD error
        """

        # reward scaling
        if self.reward_scale:
            reward = reward / 10.0  # scale reward down by a factor of 10

        # error clipping
        if self.error_clip:
            td_error = np.clip(td_error, -self.error_max, self.error_max)

        # apply alpha power
        td_error = (td_error**self.p_replay_alpha) + self.td_eps

        # ensure new experiences are sampled at least once: use the current max priority
        if self.error_init:
            td_mad = np.max(self.memory_index)
            if td_mad == 0:
                td_error = self.error_max
            else:
                td_error = td_mad

        e = self.experience(np.expand_dims(state, 0), action, reward,
                            np.expand_dims(next_state, 0), done, td_error)
        if self.useTree:
            self.tree.add(td_error, e)  # store the priority and experience in the tree
        else:
            self.memory.append(e)

        ### memory index ###
        if self.memory_pointer >= self.buffer_size:
            self.memory_index = np.roll(self.memory_index, -1)
            self.memory_index[-1] = td_error  # FIFO: overwrite the oldest entry
        else:
            self.memory_index[self.memory_pointer] = td_error
            self.memory_pointer += 1

    def update(self, td_updated, index):
        """
        Update the TD error values while preserving the order of stored experiences.
        td_updated: abs values; np.array of shape (1, batch_size, 1)
        index: leaf indices when using the tree, otherwise positions in the deque
        """
        td_updated = td_updated.squeeze()  # (batch_size,)

        # error clipping
        if self.error_clip:
            td_updated = np.clip(td_updated, -self.error_max, self.error_max)

        # apply alpha power
        td_updated = (td_updated**self.p_replay_alpha) + self.td_eps

        i = 0
        while i < len(index):
            if self.useTree:
                self.tree.update(index[i], td_updated[i])
            else:
                self.memory.rotate(
                    -index[i])  # move the target index to the front
                e = self.memory.popleft()

                td_i = td_updated[i].reshape(1, 1)

                e1 = self.experience(e.state, e.action, e.reward, e.next_state,
                                     e.done, td_i)

                self.memory.appendleft(e1)  #append the new update
                self.memory.rotate(index[i])  #restore the original order

                ### memory index ###
                self.memory_index[index[i]] = td_i

            i += 1  #increment

            # optional sanity check: memory and memory_index should stay in sync after updates

    def sample(self, p_replay_beta):
        """Sample a batch of experiences from memory."""
        l = len(self.memory)
        p_dist = (self.memory_index[:l] /
                  np.sum(self.memory_index[:l])).squeeze()

        assert (np.abs(np.sum(p_dist) - 1) < 1e-5)
        assert (len(p_dist) == l)

        # get sample of index from the p distribution
        sample_ind = np.random.choice(l, self.batch_size, p=p_dist)

        # get the selected experiences: avoid using mid list indexing
        es, ea, er, en, ed = [], [], [], [], []
        for i in sample_ind:
            self.memory.rotate(-i)
            e = copy.deepcopy(self.memory[0])
            es.append(e.state)
            ea.append(e.action)
            er.append(e.reward)
            en.append(e.next_state)
            ed.append(e.done)
            self.memory.rotate(i)

        states = torch.from_numpy(np.vstack(es)).float().to(device)
        actions = torch.from_numpy(np.vstack(ea)).long().to(device)
        rewards = torch.from_numpy(np.vstack(er)).float().to(device)
        next_states = torch.from_numpy(np.vstack(en)).float().to(device)
        dones = torch.from_numpy(np.vstack(ed).astype(
            np.uint8)).float().to(device)

        # for importance-sampling weight adjustment
        selected_td_p = p_dist[sample_ind]  # sampling probabilities of the selected experiences

        # sanity check: sampled experiences should have above-average priority
        if p_replay_beta > 0:
            if np.mean(self.memory_index[sample_ind]) < np.mean(
                    self.memory_index[:l]):
                print(np.mean(self.memory_index[sample_ind]),
                      np.mean(self.memory_index[:l]))

        # importance-sampling weight: w_i = (1/N * 1/P(i))**beta, normalized by the max weight
        weight = (1 / selected_td_p * 1 / l)**p_replay_beta
        weight = weight / np.max(weight)
        weight = torch.from_numpy(weight).float().to(device)
        assert (weight.requires_grad == False)

        return (states, actions, rewards, next_states,
                dones), weight, sample_ind

    def sample_tree(self, p_replay_beta):
        # Create arrays that will contain the minibatch
        e_s, e_a, e_r, e_n, e_d = [], [], [], [], []

        sample_ind = np.empty((self.batch_size, ), dtype=np.int32)
        sampled_td_score = np.empty((self.batch_size, 1))
        weight = np.empty((self.batch_size, 1))

        # Calculate the priority segment: as in the PER paper,
        # divide the range [0, p_total] into batch_size equal segments
        td_score_segment = self.tree.total_td_score / self.batch_size

        i = 0
        while i < self.batch_size:
            # a value is uniformly sampled from each segment
            a, b = td_score_segment * i, td_score_segment * (i + 1)
            value = np.random.uniform(a, b)

            # the experience whose cumulative priority covers this value is retrieved
            leaf_index, td_score, data = self.tree.get_leaf(value)

            #P(j)
            sampling_p = td_score / self.tree.total_td_score
            sampled_td_score[i, 0] = td_score

            # IS weight: w_i = (1/N * 1/P(i))**beta / max_w == (N * P(i))**-beta / max_w
            weight[i, 0] = (1 / self.buffer_size * 1 / sampling_p)**p_replay_beta

            sample_ind[i] = leaf_index

            e_s.append(data.state)
            e_a.append(data.action)
            e_r.append(data.reward)
            e_n.append(data.next_state)
            e_d.append(data.done)

            i += 1  # increment

        # normalize by the max weight in the batch
        max_weight = np.max(weight)
        weight = weight / max_weight

        states = torch.from_numpy(np.vstack(e_s)).float().to(device)
        actions = torch.from_numpy(np.vstack(e_a)).long().to(device)
        rewards = torch.from_numpy(np.vstack(e_r)).float().to(device)
        next_states = torch.from_numpy(np.vstack(e_n)).float().to(device)
        dones = torch.from_numpy(np.vstack(e_d).astype(
            np.uint8)).float().to(device)

        weight = torch.from_numpy(weight).float().to(device)  # convert to tensor
        assert (weight.requires_grad == False)

        return (states, actions, rewards, next_states,
                dones), weight, sample_ind

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
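
Both examples depend on a SumTree class that is not shown in this snippet. The following is a minimal sketch, assuming only the interface used above: add(priority, data), update(leaf_index, priority), get_leaf(value), a total_td_score attribute, and a flat tree array whose last buffer_size entries hold the leaf priorities.

import numpy as np


class SumTree:
    """Binary sum tree: leaves hold priorities, internal nodes hold the sum of their children."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)        # internal nodes followed by leaves
        self.data = np.zeros(capacity, dtype=object)  # experiences stored alongside the leaves
        self.data_pointer = 0

    def add(self, priority, data):
        """Write data at the next leaf (FIFO) and set its priority."""
        leaf_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(leaf_index, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, leaf_index, priority):
        """Change one leaf priority and propagate the difference up to the root."""
        change = priority - self.tree[leaf_index]
        self.tree[leaf_index] = priority
        while leaf_index != 0:
            leaf_index = (leaf_index - 1) // 2
            self.tree[leaf_index] += change

    def get_leaf(self, value):
        """Descend from the root to the leaf whose cumulative priority range covers value."""
        parent = 0
        while True:
            left, right = 2 * parent + 1, 2 * parent + 2
            if left >= len(self.tree):  # reached a leaf
                leaf_index = parent
                break
            if value <= self.tree[left]:
                parent = left
            else:
                value -= self.tree[left]
                parent = right
        data_index = leaf_index - self.capacity + 1
        return leaf_index, self.tree[leaf_index], self.data[data_index]

    @property
    def total_td_score(self):
        return self.tree[0]  # the root holds the sum of all priorities

With such a tree in place, the buffer above could be constructed as, for example, ReplayBuffer(int(1e5), 64, 1e-4, 0, 0.6, use_tree=True); these constructor values are illustrative only, not values taken from the original project.
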
Example #2
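
Example #2 below relies on the same SumTree plus a toTorch helper and a device object that are not shown here. A minimal sketch of the assumed imports and helper follows; the exact signature of toTorch in the original project may differ.

import numpy as np
import torch
from collections import deque

# assumed device definition, as is common in PyTorch projects
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def toTorch(x):
    # assumed helper: convert a numpy array to a float tensor on the target device
    return torch.from_numpy(np.asarray(x)).float().to(device)
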
class ReplayBuffer:
    def __init__(self, buffer_size, num_agents, state_size, action_size, use_PER=False):

        self.buffer_size = buffer_size
        self.use_PER = use_PER
        self.num_agents = num_agents
        self.state_size = state_size
        self.action_size = action_size

        if use_PER:
            self.tree = SumTree(buffer_size) #create tree instance
        else:
            self.memory = deque(maxlen=buffer_size)

        self.leaves_count = 0

    def add_tree(self, data, td_default=1.0):
        """PER function. Add a new experience to memory with the current max priority."""
        td_max = np.max(self.tree.tree[-self.buffer_size:])
        if td_max == 0.0:
            td_max = td_default
        self.tree.add(td_max, data)  # max priority, so new samples are likely to be selected
        self.leaves_count = min(self.leaves_count + 1, self.buffer_size)

    def add(self, data):
        """add into the buffer"""
        self.memory.append(data)

    def sample_tree(self, batch_size, p_replay_beta, td_eps=1e-4):
        """PER function. Piecewise segment sampling over the priority range."""
        s_samp, a_samp, r_samp, d_samp, ns_samp = ([] for _ in range(5))

        sample_ind = np.empty((batch_size,), dtype=np.int32)
        weight = np.empty((batch_size, 1))

        # create segments according to td score range
        td_score_segment = self.tree.total_td_score / batch_size

        for i in range(batch_size):
            # a value is uniformly sampled from each segment
            _start, _end = i * td_score_segment, (i+1) * td_score_segment
            value = np.random.uniform(_start, _end)

            # get the experience with the closest value in that segment
            leaf_index, td_score, data = self.tree.get_leaf(value)

            # the sampling prob for this sample across all tds
            sampling_p = td_score / self.tree.total_td_score

            # importance-sampling weight: (1/N * 1/P(i))**beta
            weight[i,0] = (1/sampling_p * 1/self.leaves_count)**p_replay_beta

            sample_ind[i] = leaf_index

            s_samp.append(data.states)
            a_samp.append(data.actions)
            r_samp.append(data.rewards)
            d_samp.append(data.dones)
            ns_samp.append(data.next_states)

        # note: max-weight normalization is skipped here; the weights are used unnormalized
        weight_n = toTorch(weight)

        return (s_samp, a_samp, r_samp, d_samp, ns_samp, weight_n, sample_ind)


    def sample(self, batch_size):
        """sample from the buffer"""
        sample_ind = np.random.choice(len(self.memory), batch_size)

        s_samp, a_samp, r_samp, d_samp, ns_samp = ([] for l in range(5))

        i = 0
        while i < batch_size:
            self.memory.rotate(-sample_ind[i])
            e = self.memory[0]
            s_samp.append(e.states)
            a_samp.append(e.actions)
            r_samp.append(e.rewards)
            d_samp.append(e.dones)
            ns_samp.append(e.next_states)
            self.memory.rotate(sample_ind[i])
            i += 1

        # last 2 values keep the return signature compatible with the PER path
        return (s_samp, a_samp, r_samp, d_samp, ns_samp, 1.0, [])

    def update_tree(self, td_updated, index, p_replay_alpha, td_eps=1e-4):
        """PER function.
        Update the TD error values at the given tree leaf indices.
        td_updated: abs values; np.array of shape (1, batch_size, 1)
        index: leaf indices returned by sample_tree
        """
        # apply alpha power
        td_updated = (td_updated.squeeze() ** p_replay_alpha) + td_eps

        for i in range(len(index)):
            self.tree.update(index[i], td_updated[i])

    def __len__(self):
        if not self.use_PER:
            return len(self.memory)
        else:
            return self.leaves_count
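
A short usage sketch for this second buffer, assuming a SumTree as above and an Experience namedtuple with the fields accessed in sample_tree (states, actions, rewards, dones, next_states); all numbers below are illustrative, not values from the original project.

import numpy as np
from collections import namedtuple

Experience = namedtuple("Experience",
                        ["states", "actions", "rewards", "dones", "next_states"])

buffer = ReplayBuffer(buffer_size=1024, num_agents=2, state_size=24,
                      action_size=2, use_PER=True)

# fill the buffer with random transitions for two agents
for _ in range(256):
    e = Experience(np.random.randn(2, 24), np.random.randn(2, 2),
                   np.random.randn(2, 1), np.zeros((2, 1)), np.random.randn(2, 24))
    buffer.add_tree(e)

# sample a prioritized minibatch, then refresh its priorities with new TD errors
s, a, r, d, ns, weights, idx = buffer.sample_tree(batch_size=32, p_replay_beta=0.4)
new_td = np.abs(np.random.randn(1, 32, 1))  # placeholder for |TD error| from the learner
buffer.update_tree(new_td, idx, p_replay_alpha=0.6)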