def load_data(self, buffer_filename):
    """Restore buffer contents, cursors and priorities from a pickle file.

    The file must contain a 4-tuple ``(buffer, curr_write_idx,
    available_samples, priorities)`` (presumably written by a matching
    ``save_data``); the SumTree is rebuilt from ``priorities``.

    NOTE(review): ``pickle.load`` can execute arbitrary code -- only load
    files from a trusted source.
    """
    with open(buffer_filename, 'rb') as handle:
        stored = pickle.load(handle)

    buffer, write_idx, n_available, saved_priorities = stored
    self.buffer = buffer
    self.curr_write_idx = write_idx
    self.available_samples = n_available
    self.sum_tree = SumTree(saved_priorities)
Exemple #2
0
    def __init__(self, buffer_size, batch_size, seed):
        """Create an empty prioritized replay buffer.

        Args:
            buffer_size: size passed to the backing SumTree.
            batch_size: number of experiences per sampled batch.
            seed: seed applied to Python's ``random`` module.
        """
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])
        # random.seed() returns None, so store the seed value itself instead
        # of the call's result.
        self.seed = seed
        random.seed(seed)

        self.tree = SumTree(buffer_size)
Exemple #3
0
    def __init__(self, capacity):
        '''
        Set up the prioritized replay buffer (PRB).

        Args:
            capacity: size handed to the backing SumTree; also kept on the
                instance as ``self.capacity``.
        '''
        backing_tree = SumTree(capacity)
        self.tree = backing_tree
        self.capacity = capacity
Exemple #4
0
 def __init__(self, capacity):
     """
     Build the sum tree that backs this memory.

     The tree keeps priority scores in its leaves together with a data array.
     A plain array that is overwritten once full is preferred over a deque,
     because a deque would shift every stored experience's index by one at
     each timestep.
     """
     backing_tree = SumTree(capacity)
     self.tree = backing_tree
Exemple #5
0
    def __init__(self, capacity, e=0.01, a=0.6):
        """
        :param capacity: The maximum number of samples that can be stored
        :param e: Small constant that ensures no sample has 0 priority
        :param a: Prioritization exponent (presumably applied to the error
            plus ``e`` when computing priorities -- confirm against callers)
        """
        self.capacity, self.e, self.a = capacity, e, a

        self.tree = SumTree(capacity)
Exemple #6
0
 def __init__(self, batch_size, buffer_size, seed):
     """Set prioritized-replay hyperparameters and create the backing SumTree.

     Args:
         batch_size: number of transitions per sampled batch.
         buffer_size: size passed to the SumTree.
         seed: seed applied to Python's ``random`` module.
     """
     # random.seed() returns None, so keep the seed value itself.
     self.seed = seed
     random.seed(seed)
     self.epsilon = 0.01  # small amount to avoid zero priority
     # [0~1] how strongly TD error is converted to priority; a trade-off
     # between prioritized and uniformly random sampling
     self.alpha = 0.6
     self.beta = 0.4  # importance-sampling exponent, from initial value increasing to 1
     self.beta_increment_per_sampling = 0.001
     self.batch_size = batch_size
     self.buffer_size = buffer_size
     self.sumtree = SumTree(buffer_size)
Exemple #7
0
    def __init__(self, capacity, alpha):
        """Prioritized replay memory backed by a sum tree and a min tree.

        Both trees are sized to a power of two found by doubling from 1
        while the size is still <= ``self.capacity``.
        """
        super(PrioritizedReplayMemory, self).__init__(capacity)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        size = 1
        while size <= self.capacity:
            size *= 2

        self.sum_tree = SumTree(size)
        self.min_tree = MinTree(size)
    def __init__(self, args):
        """Create online/target models, fresh from ``args`` or from a checkpoint.

        When ``args.load_path`` is None, parameter objects are built from
        ``args``. Otherwise they are taken from ``args.model_params`` /
        ``args.training_params`` and ``<load_path>/weights`` is passed to
        ``create_models``.
        """
        if args.load_path is not None:
            self.model_params = args.model_params
            self.training_params = args.training_params
            weights_path = args.load_path + '/weights'
            self.model, self.target_model = create_models(self.model_params, weights_path)
        else:
            self.model_params = ModelParameters(args)
            self.training_params = TrainingParameters(args)
            self.model, self.target_model = create_models(self.model_params)

        self.replay = SumTree()
        self.replay_index = None
    def __init__(self,
                 capacity,
                 epsilon=0.01,
                 alpha=0.6,
                 beta=0.4,
                 beta_increment=0.001):
        """Store PER hyperparameters and create the backing SumTree.

        Args:
            capacity: size passed to the SumTree.
            epsilon: constant stored for priority computation.
            alpha: priority exponent.
            beta: initial importance-sampling exponent.
            beta_increment: amount ``beta`` is meant to grow per sample call.
        """
        self.epsilon, self.alpha = epsilon, alpha
        self.beta, self.beta_increment = beta, beta_increment

        self.capacity = capacity
        self.tree = SumTree(self.capacity)
Exemple #10
0
    def __init__(self, max_size, window_size, input_shape):
        """Create the sum tree and record state geometry and priority settings.

        Args:
            max_size: size passed to the backing SumTree.
            window_size: stored as ``_window_size`` (presumably frames per
                stacked state -- confirm against callers).
            input_shape: indexable; entries 0 and 1 become ``_WIDTH`` and
                ``_HEIGHT``.
        """
        self.tree = SumTree(max_size)
        self._max_size = max_size

        # geometry used when storing state / next state
        self._window_size = window_size
        self._WIDTH = input_shape[0]
        self._HEIGHT = input_shape[1]

        # priority hyperparameters (offset and exponent)
        self.e, self.a = 0.01, 0.6
 def __init__(self, buffer_size, with_per = False):
     """Create a SumTree-backed PER buffer or a plain deque buffer.

     Args:
         buffer_size: capacity of the buffer.
         with_per: True selects prioritized experience replay.
     """
     if not with_per:
         # Standard FIFO buffer
         self.buffer = deque()
     else:
         # Prioritized Experience Replay hyperparameters + tree
         self.alpha = 0.5
         self.epsilon = 0.01
         self.buffer = SumTree(buffer_size)
     self.count = 0
     self.with_per = with_per
     self.buffer_size = buffer_size
Exemple #12
0
class Memory:  # stored as ( s, a, r, s_ ) in SumTree
    """Proportional prioritized replay memory that also tracks the largest priority.

    Priorities are ``(|error| + e) ** a``. ``max_p`` starts at 1 and is raised
    by ``update`` / ``update_batch`` whenever they compute a larger priority,
    so callers can insert new samples at that level through ``add_p``.
    """

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.max_p = 1
        # NOTE(review): with e == 0.0 an error of exactly 0 gets priority 0.
        self.e = 0.0
        self.a = 0.6

    def _getPriority(self, error):
        """Map an error (scalar or ndarray) to a sampling priority."""
        # abs(): a negative error raised to the fractional power ``a`` would
        # otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.e) ** self.a

    def length(self):
        """Return the tree's ``write`` attribute as the memory length."""
        # NOTE(review): if ``write`` is a wrap-around cursor this under-reports
        # once the tree is full -- confirm against SumTree.
        return self.tree.write

    def add(self, sample, error):
        """Insert ``sample`` with the priority derived from ``error``."""
        p = self._getPriority(error)
        self.tree.add(p, sample)

    def add_p(self, p, sample):
        """Insert ``sample`` with an already-computed priority ``p``."""
        self.tree.add(p, sample)

    def sample(self, n):
        """Draw ``n`` samples, one from each of ``n`` equal slices of the total priority.

        Returns:
            (batch, idx_batch): the sampled data and their tree indices.
        """
        batch = []
        idx_batch = []
        segment = self.tree.total() / n

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, _, data = self.tree.get(s)
            batch.append(data)
            idx_batch.append(idx)

        return batch, idx_batch

    def update(self, idx, error):
        """Set the priority of tree entry ``idx`` from ``error`` and track the max."""
        p = self._getPriority(error)
        if p > self.max_p:
            self.max_p = p
        self.tree.update(idx, p)

    def update_batch(self, idx_batch, error_batch):
        """Vectorised ``update`` for many entries at once.

        ``error_batch`` may be any sequence. It is converted to an ndarray so
        the arithmetic in ``_getPriority`` is element-wise; a plain list would
        raise TypeError.
        """
        p_batch = self._getPriority(np.asarray(error_batch))
        batch_max = np.max(p_batch)
        if batch_max > self.max_p:
            self.max_p = batch_max
        self.tree.update_batch(idx_batch, p_batch)
Exemple #13
0
    def testUpdating(self):
        """Re-pushing slot 0 replaces its value in the total at ``_data[0]``.

        ``_data[0]`` is presumably the tree root -- confirm against SumTree.
        """
        tree = SumTree(4)
        for slot, value in enumerate((1, 2, 3, 4)):
            tree.push(slot, value)
        tree.push(0, 5)

        expected_total = 2 + 3 + 4 + 5
        self.assertEqual(tree._data[0], expected_total)
Exemple #14
0
class PrioritizedReplayMemory:
    """Proportional prioritized replay memory backed by a SumTree.

    Priorities are ``(|error| + e) ** a``. Sampling draws one value uniformly
    from each of ``n`` equal slices of the total priority.
    """
    e = 0.01
    a = 0.6

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        # abs(): a negative error raised to the fractional power ``a`` would
        # otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.e) ** self.a

    def add(self, error, sample):
        """Insert ``sample`` with the priority derived from ``error``."""
        p = self._getPriority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        """Return ``n`` ``(tree_idx, data)`` pairs, one per priority segment."""
        batch = []
        segment = self.tree.total() / n

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, _, data = self.tree.get(s)
            batch.append((idx, data))

        return batch

    def update(self, idx, error):
        """Recompute and store the priority of tree entry ``idx``."""
        p = self._getPriority(error)
        self.tree.update(idx, p)

    def isFull(self):
        return self.tree.isFull()
Exemple #15
0
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """
    This SumTree code is modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        """Insert ``transition`` at the current maximum leaf priority.

        Falls back to ``abs_err_upper`` when every leaf priority is 0
        (e.g. the very first insertion).
        """
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)   # set the max p for new p

    def sample(self, n):
        """Draw ``n`` leaves, one from each equal slice of the total priority.

        Returns:
            (tree indices, stored data, importance-sampling weights of shape (n, 1)).
            ``beta`` is increased (capped at 1) on every call.
        """
        b_idx, b_memory, is_weights = [], [], []
        pri_seg = self.tree.total_p / n       # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

        # NOTE(review): weights are divided by the *largest* leaf probability,
        # so they come out >= 1 rather than <= 1 as in the PER paper (which
        # normalises by the max weight, i.e. the smallest probability). Kept.
        max_prob = np.max(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p
        for i in range(n):
            v = np.random.uniform(pri_seg * i, pri_seg * (i + 1))
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            is_weights.append(np.power(prob / max_prob, -self.beta))
            b_idx.append(idx)
            b_memory.append(data)
        return np.array(b_idx), np.array(b_memory), np.reshape(np.array(is_weights), (n, 1))

    def batch_update(self, tree_idx, abs_errors):
        """Set priorities ``min(|err| + epsilon, abs_err_upper) ** alpha`` for ``tree_idx``.

        The caller's ``abs_errors`` array is left untouched.
        """
        # Build a new array: the former in-place ``+=`` modified the caller's
        # data, and despite its comment never actually took the absolute value.
        errors = np.abs(np.asarray(abs_errors, dtype=float)) + self.epsilon
        clipped_errors = np.minimum(errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
Exemple #16
0
    def testRemoving(self):
        """After popping slot 0, ``_data[0]`` equals the value left in slot 1."""
        tree = SumTree(2)
        for slot, value in ((0, 1), (1, 2)):
            tree.push(slot, value)
        tree.pop(0)

        remaining = tree._data[0]
        self.assertEqual(remaining, 2)
Exemple #17
0
class PrioritizedER:
    """Proportional prioritized experience replay with importance-sampling weights.

    Priorities are ``(|error| + e) ** a``. Each ``sample`` call raises ``beta``
    by ``beta_increment_per_sampling`` (capped at 1) before computing weights.
    """
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (abs(error) + self.e)**self.a

    def push(self, error, sample):
        """Insert ``sample`` with the priority derived from ``error``."""
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        """Draw ``n`` transitions, one from each equal slice of the total priority.

        Returns:
            (batch, idxs, is_weight): sampled data, their tree indices, and
            importance-sampling weights normalised so the largest is 1.
        """
        batch = []
        idxs = []
        priorities = []
        total = self.tree.total()
        segment = total / n

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            (idx, p, data) = self.tree.get(s)
            # Unfilled slots appear to hold the scalar 0. np.isscalar guards the
            # comparison so array-valued data cannot trigger an ambiguous
            # element-wise ``==``.
            # NOTE(review): if the first draw hits such a slot there is no
            # previous transition to reuse and this raises IndexError.
            if np.isscalar(data) and data == 0:
                p = priorities[-1]
                data = batch[-1]
                idx = idxs[-1]
                print(
                    'WARNING: transition value was 0, replaced it with the previous sampled transition'
                )
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        # asarray: dividing a plain list by a Python float raises TypeError.
        # NOTE(review): 10e-5 equals 1e-4; kept as-is.
        sampling_probabilities = np.asarray(priorities, dtype=float) / total + 10e-5
        is_weight = np.power(self.tree.n_entries * sampling_probabilities,
                             -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def update(self, idx, error):
        """Recompute and store the priority of tree entry ``idx``."""
        p = self._get_priority(error)
        self.tree.update(idx, p)

    def __len__(self):
        return self.tree.n_entries
Exemple #18
0
class Memory:  # stored as ( s, a, r, s_ ) in SumTree
    """Proportional prioritized replay memory backed by a SumTree.

    Priorities are ``(|error| + e) ** a``; ``sample`` draws one value from
    each of ``n`` equal slices of the total priority.
    """
    e = 0.01
    a = 0.6

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        # abs(): a negative error raised to the fractional power ``a`` would
        # otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.e)**self.a

    def add(self, error, sample):
        """Insert ``sample`` with the priority derived from ``error``."""
        p = self._getPriority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        """Return ``n`` ``(tree_idx, data)`` pairs, one per priority segment."""
        batch = []
        segment = self.tree.total() / n

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, _, data = self.tree.get(s)
            batch.append((idx, data))

        return batch

    def update(self, idx, error):
        """Recompute and store the priority of tree entry ``idx``."""
        p = self._getPriority(error)
        self.tree.update(idx, p)
    def __init__(self, max_capacity: int, batch_size: int,
                 state_size: tuple, action_size: tuple):
        """Pre-allocate a circular (s, a, r, s', done) buffer and a zero-priority SumTree.

        Args:
            max_capacity: number of buffer slots.
            batch_size: number of samples per batch.
            state_size: shape of stored states and next states.
            action_size: shape of stored actions.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = max_capacity

        def _blank_experience():
            # zero-filled placeholder (s, a, r, s', d) tuple
            return (np.zeros(shape=self.state_size),
                    np.zeros(shape=self.action_size),
                    0.0,
                    np.zeros(shape=self.state_size),
                    0.0)

        self.buffer = [_blank_experience() for _ in range(self.buffer_size)]

        # every slot starts with priority zero
        self.sum_tree = SumTree([0] * self.buffer_size)

        self.curr_write_idx, self.available_samples = 0, 0

        self.beta = 0.4  # Importance sampling factor
        self.alpha = 0.6  # priority factor
        self.min_priority = 0.01
        self.batch_size = batch_size
Exemple #20
0
    def testSampling(self):
        """Empirical sampling frequencies lie within 0.1 of each slot's share of 1+2+3+4."""
        tree = SumTree(4)
        for slot in range(4):
            tree.push(slot, slot + 1)

        num_samples = int(1e5)
        res = tree.sample(num_samples)

        expected_share = {3: 0.4, 2: 0.3, 1: 0.2}
        for slot, share in expected_share.items():
            observed = np.sum(res == slot) / num_samples
            self.assertTrue(np.abs(observed - share) < 0.1)
Exemple #21
0
class Memory:
    """
    Stores transitions as (s, a, r, s_, done) tuples using a SumTree.
    Each sample is assigned a priority which affects retrieval
    """
    def __init__(self, capacity, e=0.01, a=0.6):
        """
        :param capacity: The maximum number of samples that can be stored
        :param e: Ensures that no sample has 0 priority
        :param a: Exponent applied to ``|error| + e``; 0 makes every priority
            equal to 1, larger values make priorities depend more strongly
            on the error
        """
        self.capacity = capacity
        self.e = e
        self.a = a

        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        # abs(): a negative error raised to the fractional power ``a`` would
        # otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.e)**self.a

    def add(self, error, sample):
        """
        Adds a new sample to the buffer
        :param error: The error associated with the sample
        :param sample: The sample to add
        """
        p = self._getPriority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        """
        Returns n samples from the buffer, one drawn from each of n equal
        slices of the total priority
        :param n: The number of samples to return
        :return: list of (tree_idx, data) pairs
        """
        batch = []
        segment = self.tree.total() / n

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, _, data = self.tree.get(s)
            batch.append((idx, data))

        return batch

    def update(self, idx, error):
        """
        Recomputes and stores the priority of tree entry idx
        :param idx: tree index returned by sample
        :param error: the new error for that entry
        """
        p = self._getPriority(error)
        self.tree.update(idx, p)
class PrioritizedMemory(Memory):
    """Proportional PER memory whose new transitions enter at the current max priority.

    ``beta`` grows by ``beta_increment`` on every ``sample`` call and is
    clipped to 1 when computing importance-sampling weights.
    """

    def __init__(self,
                 capacity,
                 epsilon=0.01,
                 alpha=0.6,
                 beta=0.4,
                 beta_increment=0.001):
        self.epsilon, self.alpha = epsilon, alpha
        self.beta, self.beta_increment = beta, beta_increment

        self.capacity = capacity
        self.tree = SumTree(self.capacity)

    def _compute_priority(self, loss):
        """Return ``(|loss| + epsilon) ** alpha``."""
        magnitude = np.abs(loss)
        return (magnitude + self.epsilon)**self.alpha

    def push(self, *args):
        """Store ``Transition(*args)`` at the tree's current max priority (1 if none is positive)."""
        top = self.tree.max()
        if top <= 0:
            top = 1
        self.tree.add(top, Transition(*args))

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with probability proportional to priority.

        Returns:
            (batch, indices, weights) where weights are importance-sampling
            weights scaled so their maximum is 1.
        """
        self.beta += self.beta_increment
        beta = np.minimum(1., self.beta)
        total = self.tree.total()
        draws = np.random.uniform(0, total, (batch_size, ))

        batch, indices = [], []
        weights = np.empty(batch_size, dtype='float32')
        for slot, draw in enumerate(draws):
            index, priority, data = self.tree.get(draw)
            batch.append(data)
            indices.append(index)
            weights[slot] = (self.capacity * priority / total)**(-beta)

        return batch, indices, weights / weights.max()

    def update(self, index, loss):
        """Recompute and store the priority of tree entry ``index``."""
        self.tree.update(index, self._compute_priority(loss))

    def __len__(self):
        return self.tree.n_entries
Exemple #23
0
class Memory:
    """Replay memory backed by a SumTree with a saturating sample counter.

    Priorities are ``(error + e) ** a``. With the class default ``a = 0.0``
    every priority evaluates to 1.0, so all stored samples weigh the same.
    """
    # Constants
    e = 0.01
    a = 0.0  #0.6

    def __init__(self, capacity):
        """Create an empty memory that counts up to ``capacity`` samples."""
        self.tree = SumTree(capacity)
        self.capacity = capacity
        self.len = 0

    def getPriority(self, error):
        """Convert an error into a priority: ``(error + e) ** a``."""
        shifted = error + self.e
        return shifted**self.a

    def add(self, error, sample):
        """Insert ``sample`` and bump the counter, saturating at ``capacity``."""
        self.tree.add(self.getPriority(error), sample)
        self.len = min(self.len + 1, self.capacity)

    def sample(self, n):
        """Return ``n`` ``(idx, data)`` pairs, one drawn from each equal slice of the total."""
        width = self.tree.total() / n
        picks = []
        for k in range(n):
            point = random.uniform(width * k, width * (k + 1))
            idx, _, data = self.tree.get(point)
            picks.append((idx, data))
        return picks

    def numberSamples(self):
        """Number of samples currently counted as stored."""
        return self.len

    def update(self, idx, error):
        """Overwrite the priority of tree entry ``idx``."""
        self.tree.update(idx, self.getPriority(error))
class PERMemory(ReplayMemory):
    """Replay memory with proportional prioritization stored in a SumTree."""
    epsilon = 0.0001
    alpha = 0.6

    def __init__(self, CAPACITY):
        super(PERMemory, self).__init__(CAPACITY)
        self.tree = SumTree(CAPACITY)
        self._per_capacity = CAPACITY  # upper bound for the reported size
        self.size = 0

    # Compute the priority via proportional prioritization
    def _getPriority(self, td_error):
        # abs(): a negative TD error raised to the fractional power ``alpha``
        # would otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(td_error) + self.epsilon)**self.alpha

    def push(self, state, action, state_next, reward):
        """Save state, action, state_next and reward to memory at the current max priority."""
        # Cap at capacity like the other SumTree memories: the tree only has
        # CAPACITY slots, so __len__ must not exceed it.
        self.size = min(self.size + 1, self._per_capacity)

        priority = self.tree.max()
        if priority <= 0:
            priority = 1

        self.tree.add(priority, Transition(state, action, state_next, reward))

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with probability proportional to priority.

        Returns:
            (data_list, indexes): the transitions and their tree indices.
        """
        data_list = []
        indexes = []
        for rand in np.random.uniform(0, self.tree.total(), batch_size):
            (idx, _, data) = self.tree.get(rand)
            data_list.append(data)
            indexes.append(idx)

        return data_list, indexes

    def update(self, idx, td_error):
        """Recompute and store the priority of tree entry ``idx``."""
        priority = self._getPriority(td_error)
        self.tree.update(idx, priority)

    def __len__(self):
        return self.size
class Replay_Memory:
    """SumTree-backed prioritized replay memory with stratified sampling.

    Priorities are ``(|error| + memory_bias) ** memory_pow``.
    """

    def __init__(self, memory_len=10000, memory_bias=.01, memory_pow=.6):
        """
        Args:
            memory_len: size passed to the backing SumTree.
            memory_bias: constant added to |error| so no entry gets zero priority.
            memory_pow: exponent applied to the biased error.
        """
        self.memory_len = memory_len
        self.memory_bias = memory_bias
        self.memory_pow = memory_pow
        self.tree = SumTree(self.memory_len)

    def _priority(self, error):
        # abs(): a negative error raised to the fractional power would
        # otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.memory_bias)**self.memory_pow

    def add(self, error, sample):
        """Insert ``sample`` with the priority derived from ``error``."""
        self.tree.add(self._priority(error), sample)

    def sample(self, batch_size):
        """
         Get a sample batch of the replay memory
        Returns:
         batch: list of (tree_idx, data), one sample from each of batch_size
                equal slices of the total priority
        """
        batch = []
        # one representative from every priority segment, e.g. with
        # batch_size=2 one sample from each half of the total priority
        segment = self.tree.total() / batch_size
        for i in range(batch_size):
            s = random.uniform(segment * i, segment * (i + 1))
            (idx, _, data) = self.tree.get(s)
            batch.append((idx, data))
        return batch

    def update(self, idx, error):
        """
         Updates one entry in the replay memory
        Args:
         idx: the position of the outdated transition in the memory
         error: the newly calculated error
        """
        self.tree.update(idx, self._priority(error))
class DuelingModel:
    """Q-learning agent with online/target models and a SumTree replay memory.

    Replay entries ``h`` expose ``s_t``, ``a_t``, ``r_t``, ``s_tp1``, ``last``
    and a mutable ``delta`` that is written into SumTree node sums.
    """

    def __init__(self, args):
        if args.load_path is None:
            self.model_params = ModelParameters(args)
            self.training_params = TrainingParameters(args)
            self.model, self.target_model = create_models(self.model_params)
        else:
            self.model_params = args.model_params
            self.training_params = args.training_params
            self.model, self.target_model = create_models(self.model_params, args.load_path+'/weights')

        self.replay = SumTree()
        self.replay_index = None  # created lazily once the replay is full

    def replay_batch(self):
        """Draw ``batch_size`` replay entries.

        With prioritization, one value is drawn from each of ``batch_size``
        equal slices of the root sum; otherwise ``replay.sample_random`` is used.

        Returns:
            (holders, ixs): the entries at the sampled tree indices, and the indices.
        """
        batch_size = self.training_params.batch_size
        if self.training_params.prioritize:
            sum_all = self.replay.tree[0].sum
            ixs = [self.replay.sample(sum_all / batch_size * (i + np.random.rand()))
                   for i in range(batch_size)]
        else:
            ixs = [self.replay.sample_random() for _ in range(batch_size)]
        return [self.replay.tree[ix].pointer for ix in ixs], ixs

    def _compute_targets(self, batch):
        """Return ``(pre_sample, qpre, delta)`` for ``batch``.

        ``qpre`` is the online model's prediction with each taken action's
        entry overwritten by its one-step target (reward, plus discounted max
        target-model Q unless the entry is terminal); ``delta`` is
        prediction minus target.
        """
        pre_sample = np.array([h.s_t for h in batch])
        post_sample = np.array([h.s_tp1 for h in batch])
        qpre = self.model.predict(pre_sample)
        qpost = self.target_model.predict(post_sample)

        q1 = np.zeros(qpre.shape[0])
        q2 = np.zeros_like(q1)
        for i, h in enumerate(batch):
            q1[i] = qpre[i, h.a_t]  # XXX: max instead?
            if h.last:
                qpre[i, h.a_t] = h.r_t
            else:
                qpre[i, h.a_t] = h.r_t + self.training_params.gamma * np.amax(qpost[i])
            q2[i] = qpre[i, h.a_t]
        return pre_sample, qpre, q1 - q2

    def train_on_batch(self):
        """Sample a batch, build Q targets, refresh priorities and fit the online model."""
        batch, batch_ixs = self.replay_batch()
        pre_sample, qpre, delta = self._compute_targets(batch)
        w = self.get_p_weights(delta, batch, batch_ixs)
        self.model.train_on_batch(pre_sample, qpre, sample_weight=w)

    def add_to_replay(self, h, training_started=False):
        """Insert replay entry ``h``; entries with a NaN reward are rejected.

        Before training starts ``h.delta`` is 1; afterwards it is the root sum
        divided by ``batch_size`` (intended to make sure it gets sampled).
        Once the replay holds ``replay_size`` nodes, slots yielded by
        ``replay.last_ixs()`` are overwritten instead of adding nodes.
        """
        if np.isnan(h.r_t):
            print('nan in reward (!)')
            return
        if training_started:
            h.delta = self.replay.tree[0].sum/self.training_params.batch_size  # make sure it'll be sampled once
        else:
            h.delta = 1.
        if len(self.replay.tree) >= self.training_params.replay_size:
            if self.replay_index is None:
                self.replay_index = self.replay.last_ixs()
            ix = next(self.replay_index)
            self.replay.tree[ix].pointer = h
            self.replay.tree[ix].sum = h.delta
            self.replay.update(ix)
        else:
            self.replay.add_node(h)

    def get_p_weights(self, delta, batch, batch_ixs):
        """Output weights for prioritizing bias compensation.

        With prioritization, weights are ``1 / p`` (``p = h.delta / root sum``)
        scaled to a maximum of 1. As a side effect each entry's ``delta`` is
        replaced by ``|new delta|`` (NaN -> 0) and its tree node is refreshed.
        Without prioritization all weights are 1.
        """
        if self.training_params.prioritize:
            p_total = self.replay.tree[0].sum
            p = np.array([h.delta for h in batch]) / p_total
            w = 1. / p
            w /= np.max(w)
            for ix, h, d in zip(batch_ixs, batch, delta):
                h.delta = np.nan_to_num(np.abs(d))  # catch nans
                self.replay.update(ix)
        else:
            w = np.ones(len(batch))
        return w

    def get_delta(self, batch):
        """Return prediction-minus-target TD differences for ``batch``."""
        return self._compute_targets(batch)[2]

    def heap_update(self, chunk_size=10000):
        """Every n steps, recalculate deltas in the sumtree, ``chunk_size`` entries at a time."""
        print('SumTree pre-update: {}'.format(self.replay.tree[0].sum))
        last_ixs = self.replay.last_ixs(True)
        for start in range(0, len(last_ixs), chunk_size):
            ixs = last_ixs[start:start + chunk_size]
            batch = [self.replay.tree[ix].pointer for ix in ixs]
            self.get_p_weights(self.get_delta(batch), batch, ixs)
        print('SumTree post-update: {}'.format(self.replay.tree[0].sum))
        print('SumTree updated')

    def save(self, save_path, folder_name):
        """Save target-model weights and pickled (model_params, training_params)."""
        target_dir = save_path + folder_name
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        self.target_model.save_weights(target_dir + '/weights', overwrite=True)
        info = self.model_params, self.training_params
        try:
            import cPickle as pickler  # Python 2
        except ImportError:
            import pickle as pickler
        # 'wb': pickle output is binary, and the former 'wr' mode string is
        # rejected by Python 3's open(). ``with`` closes the file on errors.
        with open(target_dir + '/model_params', 'wb') as d_file:
            pickler.dump(info, d_file)
class Memory:
    """Circular replay buffer with SumTree-based prioritized sampling.

    Each experience is a ``(state, action, reward, next_state, done)`` tuple.
    Priorities written to the tree are ``(priority + min_priority) ** alpha``.
    """

    def __init__(self, max_capacity: int, batch_size: int,
                 state_size: tuple, action_size: tuple):
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = max_capacity

        # buffer to store (s,a,r,s',d) tuples, pre-filled with zero placeholders
        self.buffer = [(np.zeros(shape=self.state_size),
                        np.zeros(shape=self.action_size),
                        0.0,
                        np.zeros(shape=self.state_size),
                        0.0) for _ in range(self.buffer_size)]

        # Initially all priorities are set to zero
        self.sum_tree = SumTree([0] * self.buffer_size)

        self.curr_write_idx = 0
        self.available_samples = 0

        self.beta = 0.4  # Importance sampling factor
        self.alpha = 0.6  # priority factor
        self.min_priority = 0.01
        self.batch_size = batch_size

    def __len__(self):
        return self.available_samples

    def record(self, experience: tuple, priority: float):
        """Write ``experience`` at the cursor, set its priority and advance the cursor."""
        self.buffer[self.curr_write_idx] = experience
        self.update(self.curr_write_idx, priority)

        self.curr_write_idx += 1
        if self.curr_write_idx >= self.buffer_size:
            self.curr_write_idx = 0  # wrap: the oldest slot is overwritten next

        if self.available_samples < self.buffer_size:
            self.available_samples += 1  # max value = self.buffer_size

    def adjust_priority(self, priority: float):
        return np.power(priority + self.min_priority, self.alpha)

    def update(self, idx: int, priority: float):
        self.sum_tree.update(self.sum_tree.leaf_nodes[idx],
                             self.adjust_priority(priority))

    def sample(self):
        """Make ``batch_size`` draws from the sum tree.

        Draws landing on slots at or beyond ``available_samples`` are skipped,
        so the batch may be smaller than ``batch_size`` (possibly empty).

        Returns:
            (states, actions, rewards, next_states, dones, sampled_idxs,
            is_weights); ``is_weights`` are normalised so their maximum is 1.
        """
        sampled_idxs = []
        is_weights = []  # importance sampling weights
        for _ in range(self.batch_size):
            sample_val = np.random.uniform(0, self.sum_tree.root_node.value)
            sample_node = self.sum_tree.retrieve(sample_val, self.sum_tree.root_node)

            if sample_node.idx < self.available_samples:
                sampled_idxs.append(sample_node.idx)
                p = sample_node.value / (self.sum_tree.root_node.value + 1e-3)  # avoid singularity
                is_weights.append(self.available_samples * p)

        if not sampled_idxs:
            # np.max below would raise on an empty array.
            return [], [], [], [], [], [], np.array([])

        # apply beta factor and normalise so that the maximum is_weight is 1
        is_weights = np.power(np.array(is_weights), -self.beta)
        is_weights = is_weights / np.max(is_weights)

        experiences = [self.buffer[i] for i in sampled_idxs]
        state_batch = [e[0] for e in experiences]
        action_batch = [e[1] for e in experiences]
        reward_batch = [e[2] for e in experiences]
        next_state_batch = [e[3] for e in experiences]
        done_batch = [e[4] for e in experiences]
        return state_batch, action_batch, reward_batch,\
               next_state_batch, done_batch, sampled_idxs, is_weights

    def save_priorities_txt(self, filename):
        """Write one ``index<TAB>priority`` line per buffer slot."""
        priorities = self.sum_tree.get_priorities()

        with open(filename, 'w') as file:
            file.writelines('{}\t{}\n'.format(i, priorities[i])
                            for i in range(self.buffer_size))

    def save_data(self, buffer_filename):
        """Pickle ``(buffer, curr_write_idx, available_samples, priorities)``."""
        priorities = self.sum_tree.get_priorities()
        parameters = (
            self.buffer,
            self.curr_write_idx,
            self.available_samples,
            priorities
        )
        with open(buffer_filename, 'wb') as file:
            pickle.dump(parameters, file)

    def load_data(self, buffer_filename):
        """Restore state written by ``save_data`` and rebuild the sum tree.

        NOTE(review): ``pickle.load`` can execute arbitrary code -- only load
        trusted files.
        """
        with open(buffer_filename, 'rb') as file:
            parameters = pickle.load(file)

        self.buffer, self.curr_write_idx, \
            self.available_samples, priorities = parameters
        self.sum_tree = SumTree(priorities)
Exemple #28
0
 def __init__(self, capacity):
     """Create the SumTree and start with a sample count of zero."""
     self.tree = SumTree(capacity)
     self.capacity = capacity
     self.len = 0  # presumably incremented (up to capacity) by an add method -- confirm
Exemple #29
0
 def __init__(self, capacity):
     """Create the SumTree and the priority hyperparameters."""
     self.tree = SumTree(capacity)
     self.max_p = 1  # presumably the largest priority seen so far -- confirm
     self.e, self.a = 0.0, 0.6
 def clear(self):
     """ Clear buffer / Sum Tree

     Recreates an empty SumTree (PER) or deque and resets the count.
     """
     # The capacity must come from the instance: a bare ``buffer_size`` is not
     # defined in this method and raised NameError.
     if self.with_per:
         self.buffer = SumTree(self.buffer_size)
     else:
         self.buffer = deque()
     self.count = 0
class MemoryBuffer(object):
    """ Memory Buffer Helper class for Experience Replay
    using a double-ended queue or a Sum Tree (for PER)
    """
    def __init__(self, buffer_size, with_per = False):
        """ Initialization

        Args:
            buffer_size: maximum number of experiences kept.
            with_per: if True use a SumTree (prioritized replay), else a deque.
        """
        if with_per:
            # Prioritized Experience Replay
            self.alpha = 0.5
            self.epsilon = 0.01
            self.buffer = SumTree(buffer_size)
        else:
            # Standard Buffer
            self.buffer = deque()
        self.count = 0
        self.with_per = with_per
        self.buffer_size = buffer_size

    def memorize(self, state, action, reward, done, new_state, achieved_goal, goal, error=None):
        """ Save an experience to memory, optionally with its TD-Error

        With PER, ``error`` must be indexable; its first element sets the priority.
        """
        experience = (state, action, reward, done, new_state, achieved_goal, goal, error)
        if self.with_per:
            priority = self.priority(error[0])
            self.buffer.add(priority, experience)
            # The SumTree was created with buffer_size slots, so never report
            # more than that (mirrors the deque branch below).
            self.count = min(self.count + 1, self.buffer_size)
        elif self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        else:
            # Buffer full: drop the oldest experience first.
            self.buffer.popleft()
            self.buffer.append(experience)

    def priority(self, error):
        """ Compute an experience priority, as per Schaul et al.: (|error| + epsilon) ** alpha
        """
        # abs(): a negative TD error raised to the fractional power alpha
        # would otherwise give a complex number (Python float) or NaN (NumPy).
        return (abs(error) + self.epsilon) ** self.alpha

    def size(self):
        """ Current Buffer Occupation
        """
        return self.count

    def sample_batch(self, batch_size):
        """ Sample a batch, optionally with (PER)

        Returns:
            (states, actions, rewards, dones, new_states, achieved_goals,
            goals, idx) where ``idx`` is an array of tree indices under PER
            and None otherwise.
        """
        if self.with_per:
            # True division: floor division made the segment 0 whenever the
            # total priority was below batch_size, so every draw hit s = 0.
            segment = self.buffer.total() / batch_size
            batch = []
            tree_idxs = []
            for i in range(batch_size):
                s = random.uniform(segment * i, segment * (i + 1))
                tree_idx, _, data = self.buffer.get(s)
                batch.append(data)
                tree_idxs.append(tree_idx)
            # Collected separately: the old code read field 7 of
            # (*data, idx), which is the stored error, not the tree index.
            idx = np.array(tree_idxs)
        # Sample randomly from Buffer
        elif self.count < batch_size:
            idx = None
            batch = random.sample(self.buffer, self.count)
        else:
            idx = None
            batch = random.sample(self.buffer, batch_size)

        # Return a batch of experience
        s_batch = np.array([i[0] for i in batch])
        a_batch = np.array([i[1] for i in batch])
        r_batch = np.array([i[2] for i in batch])
        d_batch = np.array([i[3] for i in batch])
        new_s_batch = np.array([i[4] for i in batch])
        ag_batch = np.array([i[5] for i in batch])
        g_batch = np.array([i[6] for i in batch])
        return s_batch, a_batch, r_batch, d_batch, new_s_batch, ag_batch, g_batch, idx

    def update(self, idx, new_error):
        """ Update priority for idx (PER)
        """
        self.buffer.update(idx, self.priority(new_error))

    def clear(self):
        """ Clear buffer / Sum Tree
        """
        # Use the instance attribute; a bare ``buffer_size`` raised NameError.
        if self.with_per:
            self.buffer = SumTree(self.buffer_size)
        else:
            self.buffer = deque()
        self.count = 0
Exemple #32
0
 def __init__(self, capacity):
     """Wrap a SumTree built with the given capacity."""
     backing_tree = SumTree(capacity)
     self.tree = backing_tree
Exemple #33
0
from sumtree import SumTree




# Demo: add p dummy transitions with priority 10000 to a SumTree created
# with memory_size=10, then print its tree array and stored transitions.
tree = SumTree(memory_size=10)
p = 1  # number of insertions
dummy_transition = (1, 1, 1, 1, 1)
for i in range(p):
    tree.add(10000, dummy_transition)
print("tree", tree.tree)
print("transition", tree.transitions)