Example #1
import numpy as np

# Assumes a SumTree implementation (defined elsewhere) exposing
# add(), update(), sample(), total_p() and the raw `tree` array.


class Memory(object):
    def __init__(self, capacity, batch_size):
        self.capacity = capacity
        self.batch_size = batch_size
        self.tree = SumTree(capacity=capacity)

        self.alpha = 0.6        # priority exponent: 0 = uniform sampling
        self.beta = 0.4         # importance-sampling exponent
        self.p_epsilon = 1e-4   # small constant so no priority is exactly zero

    def _get_priority(self, priorities):
        # Convert absolute TD errors to priorities: add a small epsilon so no
        # transition ends up with zero probability, clip to an upper bound of 1,
        # then apply the prioritization exponent alpha.
        priorities += self.p_epsilon
        priorities = np.minimum(priorities, 1.0)
        priorities = np.power(priorities, self.alpha)
        return priorities

    def store(self, transition):
        # New transitions receive the current maximum leaf priority so that
        # they are sampled at least once before being re-prioritized.
        max_p = np.max(self.tree.tree[-self.capacity:])
        if max_p == 0:
            max_p = 1.0
        self.tree.add(transition, max_p)

    def sample(self):
        # Stratified sampling: split the total priority mass into
        # `batch_size` equal segments and draw one transition from each.
        avg_p = self.tree.total_p() / self.batch_size
        batch_tree_idx, batch_p, batch_data = [], [], []
        for i in range(self.batch_size):
            a, b = avg_p * i, avg_p * (i + 1)
            s = np.random.uniform(a, b)
            tree_idx, p, data = self.tree.sample(s)
            batch_tree_idx.append(tree_idx)
            batch_p.append(p)
            batch_data.append(data)
        # Turn the sampled priorities into probabilities, then compute the
        # importance-sampling weights w_i = (N * P(i))^(-beta), normalized
        # by the largest weight for stability.
        batch_p = np.array(batch_p) / self.tree.total_p()
        batch_weight = np.power(batch_p * self.capacity, -self.beta)
        batch_weight = batch_weight / batch_weight.max()
        batch_tree_idx, batch_data, batch_weight = map(
            np.array, [batch_tree_idx, batch_data, batch_weight])
        return batch_tree_idx, batch_data, batch_weight

    def update(self, tree_idx, priorities):
        priorities = self._get_priority(priorities)
        for index, p in zip(tree_idx, priorities):
            self.tree.update(index, p)
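
Neither example shows the SumTree it relies on. Below is a minimal sketch of the interface Example #1 uses (add(data, p), sample(s) returning (tree_index, priority, data), update(tree_idx, p), total_p(), and a raw tree array whose last `capacity` entries are the leaf priorities). It is an assumption for illustration, not the original data structure; Example #2 below expects slightly different method names (add(p=..., data=...), get(s), max_p(), __len__()).

import numpy as np


class SumTree:
    """Minimal sum-tree sketch: leaves hold priorities, internal nodes hold partial sums."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)      # internal nodes followed by leaves
        self.data = np.zeros(capacity, dtype=object)
        self.write = 0                              # index of the next leaf to overwrite
        self.size = 0

    def total_p(self):
        return self.tree[0]                         # root holds the total priority

    def add(self, data, p):
        leaf = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(leaf, p)
        self.write = (self.write + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:                        # propagate the change up to the root
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def sample(self, s):
        idx = 0
        while True:                                 # walk down until a leaf is reached
            left, right = 2 * idx + 1, 2 * idx + 2
            if left >= len(self.tree):
                break
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = right
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]

    def __len__(self):
        return self.size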

Example #2

import random
from collections import namedtuple

import numpy as np

# A SumTree is again assumed to be available, here with a slightly different
# interface: add(p, data), get(s), max_p(), total_p(), update(tree_idx, p)
# and __len__().


class PrioritizedReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, buffer_size, seed):
        """Initialize a PrioritizedReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum number of experiences held in the SumTree
            seed (int): random seed
        """
        self.memory = SumTree(buffer_size)
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

        # epsilon: small constant added to each priority to avoid zero-probability transitions
        # alpha: in [0, 1], determines how much prioritization is used; alpha=0 recovers the uniform case
        # beta: controls how strongly the importance-sampling weights compensate for the non-uniform
        #   sampling probabilities; beta=1 compensates fully. Unbiased updates matter most near
        #   convergence at the end of training, so beta follows a schedule that starts from an
        #   initial value and reaches 1 only at the end of learning.

        self.epsilon = 0.01
        self.alpha = 0.6
        
        beta_start = 0.4
        self.beta_end = 1.0
        self.beta = beta_start
        beta_increments = 200   # anneal beta from 0.4 to 1.0 over 200 calls to increase_beta()
        self.beta_increment = (self.beta_end - beta_start) / beta_increments

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory with the current maximum priority."""
        experience = self.experience(state, action, reward, next_state, done)
        p = self.memory.max_p()
        if p == 0:
            p = 1.0
        self.memory.add(p=p, data=experience)

    def sample(self, n):
        """Sample a batch of n experiences, stratified over the priority mass."""
        experiences = []
        indices = []
        priorities = []
        # Split the total priority into n equal segments and draw one
        # experience from each segment (stratified proportional sampling).
        segment = self.memory.total_p() / n
        for i in range(n):
            a = segment * i
            b = segment * (i + 1)

            s = random.uniform(a, b)
            (idx, p, experience) = self.memory.get(s)
            experiences.append(experience)
            indices.append(idx)
            priorities.append(p)
        priorities = np.array(priorities, dtype=np.float64)
        indices = np.array(indices, dtype=np.int32)

        # Turn priorities into sampling probabilities, then compute the
        # importance-sampling (IS) weights (N * P(i))^(-beta), normalized by
        # the maximum weight so they only ever scale updates downward.
        probs = priorities / self.memory.total_p()
        w_is = (self.memory.capacity * probs) ** (-self.beta)
        w_is_normalized = w_is / w_is.max()

        return experiences, indices, w_is_normalized

    def update_errors(self, indices, errors):
        """Update leaf priorities from the latest absolute TD errors."""
        priorities = [self._to_priority(e) for e in errors]
        for (idx, p) in zip(indices, priorities):
            self.memory.update(idx, p)

    def _to_priority(self, error):
        # p_i = (|delta_i| + epsilon)^alpha
        return (error + self.epsilon) ** self.alpha
    
    def increase_beta(self):
        if self.beta < self.beta_end:
            self.beta = min(self.beta_end, self.beta + self.beta_increment)

    def __len__(self):
        return len(self.memory)
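
For completeness, here is a rough sketch of how Example #2's buffer might sit inside a training loop. The buffer size, batch size, number of steps, the dummy transitions, and the compute_td_errors placeholder are all assumptions for illustration, not part of the original example.

import numpy as np

buffer = PrioritizedReplayBuffer(buffer_size=2 ** 16, seed=0)


def compute_td_errors(batch):
    # Placeholder: a real agent would return |Q_target - Q_expected| per sample here.
    return np.random.rand(len(batch))


# Fill the buffer with dummy (state, action, reward, next_state, done) transitions.
for t in range(1000):
    buffer.add(state=t, action=0, reward=0.0, next_state=t + 1, done=False)

for step in range(200):
    experiences, indices, weights = buffer.sample(n=32)
    # `weights` would multiply the per-sample loss to correct the sampling bias.
    errors = compute_td_errors(experiences)
    buffer.update_errors(indices, errors)   # re-prioritize the sampled leaves
    buffer.increase_beta()                  # anneal beta toward 1 over training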