Beispiel #1
0
class PrioritisedMemory(object):
    """Prioritised experience replay buffer backed by a ``SumTree``.

    Priorities are ``(|td_error| + epsilon) ** alpha``. The
    importance-sampling exponent ``beta`` starts at the ``beta`` constructor
    argument and grows by ``(beta_end - beta_start) / num_steps`` every time
    a memory is added, capped at 1.0.
    """

    def __init__(self, alpha, beta, beta_end, epsilon, num_steps, replay_size):
        """Store the hyper-parameters and create the backing ``SumTree``.

        Args:
            alpha: Exponent applied to the shifted absolute TD error.
            beta: Initial importance-sampling exponent.
            beta_end: Value used to compute the per-step beta increment.
            epsilon: Constant added to ``|td_error|`` before exponentiation.
            num_steps: Number of ``add_memory`` calls over which beta moves
                from ``beta`` to ``beta_end``.
            replay_size: Capacity passed to ``SumTree``.
        """
        self.alpha = alpha
        self.beta_start = beta
        self.beta_end = beta_end
        self.beta = beta
        self.epsilon = epsilon
        self.num_steps = num_steps

        self.memory = SumTree(replay_size)
        self.replay_size = replay_size

    def proprotional_priority(self, td_error):
        """Return ``(|td_error| + epsilon) ** alpha``.

        The method name keeps its original spelling so existing callers
        continue to work.
        """
        return (np.abs(td_error) + self.epsilon) ** self.alpha

    def add_memory(self, td_error, data):
        """Insert ``data`` with a priority derived from ``td_error``.

        Also advances ``self.beta`` by one annealing step.
        """
        priority = self.proprotional_priority(td_error)
        self.memory.add_memory(data, priority)

        # NOTE(review): the cap is the constant 1.0, not ``beta_end``; with
        # ``beta_end < 1`` beta keeps growing past ``beta_end`` -- confirm
        # this is intended.
        beta_step = (self.beta_end - self.beta_start) / self.num_steps
        self.beta = min(1.0, self.beta + beta_step)

    def update_priority(self, index, td_error):
        """Recompute the priority for ``td_error`` and store it at ``index``."""
        new_priority = self.proprotional_priority(td_error)
        self.memory.update_priority(index, new_priority)

    def minibatch_sample(self, minibatch_size):
        """Draw a stratified minibatch and its importance-sampling weights.

        The total priority mass is split into ``minibatch_size`` equal
        intervals and one value is drawn uniformly from each; the tree is
        queried with that value.

        Args:
            minibatch_size: Number of samples to draw.

        Returns:
            A tuple ``(priority_indexes, samples, importance_weights)`` where
            the first two are lists and the last is a NumPy array normalised
            so that its maximum is 1.
        """
        samples = []
        priorities = []
        priority_indexes = []

        # Read the total once; it is used for both the stratification
        # interval and the sampling probabilities.
        total_priority = self.memory.priority_total()
        interval = total_priority / minibatch_size

        for i in range(minibatch_size):
            sample = np.random.uniform(i * interval, (i + 1) * interval)
            priority_index, priority, data = self.memory.get(sample)

            samples.append(data)
            priorities.append(priority)
            priority_indexes.append(priority_index)

        # Convert to an array first: dividing a plain list by a Python float
        # raises TypeError.
        sampling_probabilities = (
            np.asarray(priorities, dtype=float) / total_priority
        )
        # assumes the SumTree exposes ``replay_size`` -- TODO confirm
        importance_weights = np.power(
            self.memory.replay_size * sampling_probabilities, -self.beta)
        # Normalise by the largest weight (the original referenced an
        # undefined name here).
        importance_weights = importance_weights / np.max(importance_weights)

        return priority_indexes, samples, importance_weights
Beispiel #2
0
class Memory:
    """Prioritised replay memory stored in a ``SumTree``.

    Priorities are ``(|error| + error_multiplier) ** alpha``. The
    importance-sampling exponent ``beta`` starts at the ``beta`` constructor
    argument and is increased by ``beta_increment_per_sample`` on every
    ``sample_tree`` call, capped at 1.0.
    """

    def __init__(self,
                 tree_memory_length,
                 error_multiplier=0.01,
                 alpha=0.6,
                 beta=0.4,
                 beta_increment_per_sample=0.001):
        """Create the backing tree and store the hyper-parameters.

        Args:
            tree_memory_length: Capacity passed to ``SumTree``.
            error_multiplier: Constant added to ``|error|`` before
                exponentiation.
            alpha: Exponent applied to the shifted absolute error.
            beta: Initial importance-sampling exponent.
            beta_increment_per_sample: Amount added to beta on each
                ``sample_tree`` call.
        """
        self.tree = SumTree(tree_memory_length)
        self.tree_memory_length = tree_memory_length
        self.error_multiplier = error_multiplier
        self.per_alpha = alpha
        self.per_beta_init = beta
        # Running exponent; advanced by sample_tree().
        self.beta = beta
        self.beta_increment_per_sample = beta_increment_per_sample

    def _get_priority(self, error):
        """Return ``(|error| + error_multiplier) ** alpha``."""
        return (np.abs(error) + self.error_multiplier) ** self.per_alpha

    def add_sample_to_tree(self, error, sample):
        """Insert ``sample`` with a priority derived from ``error``."""
        priority = self._get_priority(error)
        self.tree.add(priority, sample)

    def sample_tree(self, num_samples):
        """Draw a stratified batch and its importance-sampling weights.

        The total priority mass is split into ``num_samples`` equal segments
        and one value is drawn uniformly from each; the tree is queried with
        that value.

        Args:
            num_samples: Number of samples to draw.

        Returns:
            A tuple ``(batch, idxs, is_weight)`` where the first two are
            lists and the last is a NumPy array normalised so that its
            maximum is 1.
        """
        batch = []
        idxs = []
        priorities = []

        # Read the total once; it is used for both the segment width and the
        # sampling probabilities.
        total_priority = self.tree.sum_of_tree()
        segment = total_priority / num_samples

        # Anneal from the running value. The original restarted from
        # ``per_beta_init`` on every call, so beta never moved past a single
        # increment.
        self.beta = min(1.0, self.beta + self.beta_increment_per_sample)

        for i in range(num_samples):
            low = segment * i
            high = segment * (i + 1)

            sample = random.uniform(low, high)
            idx, priority, data = self.tree.get_sample(sample)

            priorities.append(priority)
            batch.append(data)
            idxs.append(idx)

        # Convert to an array first: dividing a plain list by a Python float
        # raises TypeError, and ``.max()`` below needs an ndarray.
        sampling_prob = np.asarray(priorities, dtype=float) / total_priority
        # assumes the SumTree exposes ``num_entries`` -- TODO confirm
        is_weight = np.power(self.tree.num_entries * sampling_prob, -self.beta)
        is_weight = is_weight / is_weight.max()

        return batch, idxs, is_weight

    def update_tree(self, idx, error):
        """Recompute the priority for ``error`` and store it at ``idx``."""
        priority = self._get_priority(error)
        self.tree.update_priority(idx, priority)