Example #1
import numpy as np

# SumTree is assumed to be an external helper; a minimal sketch follows this example.
class ReplayMemory:
    def __init__(self, memory_size):
        self.memory_size = memory_size
        self.memory = SumTree(memory_size)
        self.epsilon = 0.0001  # small amount to avoid zero priority
        self.alpha = 0.6  # adj_pri = pri^alpha
        self.beta = 0.4  # importance-sampling, from initial value increasing to 1
        self.beta_max = 1
        self.beta_increment_per_sampling = 0.001
        self.abs_err_upper = 1.  # clipped td error

    def add(self, row):
        max_p = np.max(
            self.memory.tree[-self.memory.capacity:])  # max adj_pri of leaves
        if max_p == 0:
            max_p = self.abs_err_upper
        self.memory.add(max_p, row)  # set the max adj_pri for new adj_pri

    def get_batch(self, batch_size):
        leaf_idx = np.empty(batch_size, dtype=np.int32)
        batch_memory = np.empty(batch_size, dtype=object)
        ISWeights = np.empty(batch_size)
        pri_seg = self.memory.total_p / batch_size  # adj_pri segment
        self.beta = np.min(
            [self.beta_max,
             self.beta + self.beta_increment_per_sampling])  # max = 1

        # P(i) = adj_pri(i) / ∑_k adj_pri(k)  (proportional to priority, not a softmax)
        # ISWeight_j = (N*P(j))^(-beta) / max_i[(N*P(i))^(-beta)] = (P(j) / min_i[P(i)])^(-beta)
        # min probability over the filled leaves only (empty leaves hold 0)
        min_prob = np.min(
            self.memory.tree[self.memory.capacity - 1:self.memory.capacity -
                             1 + self.memory.counter]) / self.memory.total_p
        for i in range(batch_size):
            # sample from each interval
            a, b = pri_seg * i, pri_seg * (i + 1)  # interval
            v = np.random.uniform(a, b)
            idx, p, data = self.memory.get_leaf(v)
            prob = p / self.memory.total_p
            ISWeights[i] = np.power(prob / min_prob, -self.beta)
            leaf_idx[i], batch_memory[i] = idx, data
        return leaf_idx, batch_memory, ISWeights

    def update_sum_tree(self, tree_idx, td_errors):
        for ti, td_error in zip(tree_idx, td_errors):
            p = self._calculate_priority(td_error)
            self.memory.update(ti, p)

    def _calculate_priority(self, td_error):
        priority = abs(td_error) + self.epsilon
        clipped_pri = np.minimum(priority, self.abs_err_upper)
        return np.power(clipped_pri, self.alpha)

    @property
    def length(self):
        return self.memory.counter

    def load_memory(self, memory):
        self.memory = memory

    def get_memory(self):
        return self.memory
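
All four examples rely on a SumTree class that none of them defines. Below is a minimal, hypothetical sketch matching the interface Example #1 uses (tree, capacity, counter, total_p, add, get_leaf, update); the other examples assume slightly different variants of the same structure. Leaves occupy the last capacity slots of the tree array, which is why the examples slice tree[-capacity:].

import numpy as np

class SumTree:
    """Minimal binary sum tree: leaves hold priorities, internal nodes
    hold the sum of their children, so the root is the total priority."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # internal nodes + leaves
        self.data = np.empty(capacity, dtype=object)  # stored transitions
        self.data_pointer = 0  # next write slot (circular)
        self.counter = 0  # number of filled slots, capped at capacity

    @property
    def total_p(self):
        return self.tree[0]  # the root holds the sum of all priorities

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1  # leaf index
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        self.counter = min(self.counter + 1, self.capacity)

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:  # propagate the change up to the root
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        parent = 0
        while True:  # walk down: left if v fits in the left sum, else go right
            left, right = 2 * parent + 1, 2 * parent + 2
            if left >= len(self.tree):  # no children: parent is a leaf
                leaf = parent
                break
            if v <= self.tree[left]:
                parent = left
            else:
                v -= self.tree[left]
                parent = right
        return leaf, self.tree[leaf], self.data[leaf - self.capacity + 1]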
Example #2
import numpy as np
import torch

# A SumTree helper and a torch device are assumed to be defined elsewhere, e.g.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Memory(object):
    def __init__(self,
                 capacity,
                 state_size=37,
                 epsilon=0.001,
                 alpha=0.4,
                 beta=0.3,
                 beta_increment_per_sampling=0.001,
                 abs_err_upper=1):
        self.tree = SumTree(capacity)
        self.epsilon = epsilon  # avoid zero priority, which would never be sampled
        self.alpha = alpha  # trade off priority vs. uniformity: alpha = 0 is uniform sampling, alpha = 1 is pure priority
        self.beta = beta  # importance-sampling exponent, annealed from its initial value towards 1
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.abs_err_upper = abs_err_upper  # clip point for the absolute TD error
        self.state_size = state_size

    # Save experience in memory
    def store(self, state, action, reward, next_state, done):
        transition = [state, action, reward, next_state, done]
        max_p = np.max(self.tree.tree[-self.tree.capacity:])

        # In case of no priority, we set abs error to 1
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # set the max p for new p

    # Sample n amount of experiences using prioritized experience replay
    def sample(self, n):
        b_idx = np.empty((n, ), dtype=np.int32)
        states = np.empty((n, self.state_size))
        actions = np.empty((n, ))
        rewards = np.empty((n, ))
        next_states = np.empty((n, self.state_size))
        dones = np.empty((n, ))
        ISWeights = np.empty((n, ))  # IS -> Importance Sampling

        pri_seg = self.tree.total_p / n  # priority segment
        self.beta = np.min([
            1., self.beta + self.beta_increment_per_sampling
        ])  # anneal beta towards 1 so IS corrections strengthen over training

        # Note: unlike Examples #1 and #3, the ISWeights below are left unnormalized;
        # the usual max-weight normalization would divide by the weight of
        # min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p

        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i] = np.power(prob, -self.beta)
            b_idx[i] = idx
            states[i, :] = data[0]
            actions[i] = data[1]
            rewards[i] = data[2]
            next_states[i, :] = data[3]
            dones[i] = data[4]

        states = torch.from_numpy(np.vstack(states)).float().to(device)
        actions = torch.from_numpy(np.vstack(actions)).long().to(device)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
        next_states = torch.from_numpy(
            np.vstack(next_states)).float().to(device)
        dones = torch.from_numpy(np.vstack(dones).astype(
            np.uint8)).float().to(device)
        ISWeights = torch.from_numpy(np.vstack(ISWeights)).float().to(device)

        return b_idx, states, actions, rewards, next_states, dones, ISWeights

    # Update the priorities according to the new errors
    def batch_update(self, tree_idx, abs_errors):
        abs_errors = np.abs(abs_errors) + self.epsilon  # take the absolute error and avoid 0
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)

    def __len__(self):
        return self.tree.length()
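
For context, a hedged sketch of how Example #2's Memory might sit inside a DQN learning step, with the IS weights scaling the squared TD errors before the gradient step; qnet, target_net, optimizer and gamma are hypothetical placeholders, not part of the original snippet.

def learn_step(memory, qnet, target_net, optimizer, gamma=0.99, batch_size=64):
    idx, states, actions, rewards, next_states, dones, isw = memory.sample(batch_size)

    q = qnet(states).gather(1, actions)  # Q(s, a) for the taken actions
    with torch.no_grad():
        q_next = target_net(next_states).max(1, keepdim=True)[0]
        target = rewards + gamma * q_next * (1 - dones)

    td_errors = target - q
    loss = (isw * td_errors.pow(2)).mean()  # IS-weighted squared TD error

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # feed the new |TD| errors back so the sampled priorities stay current
    memory.batch_update(idx, td_errors.abs().detach().cpu().numpy().flatten())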
Example #3
import numpy as np
from threading import Lock

# SimpleMemory (flat replay buffers) and a SumTree variant are assumed to be defined elsewhere.
class PriorityMemory(SimpleMemory):
    PER_e = 0.01  # small constant so no experience ends up with zero probability of being sampled
    PER_a = 0.6  # trade-off between sampling only high-priority experiences and sampling uniformly
    PER_b = 0.4  # importance-sampling exponent, annealed from its initial value towards 1
    PER_b_increment_per_sampling = 0.001
    absolute_error_upper = 1.  # clip point for the absolute TD error

    def __init__(self, obs_dim, act_dim, size, act_dtype):
        SimpleMemory.__init__(self, obs_dim, act_dim, size, act_dtype)
        self.tree = SumTree(size)
        self.tree_lock = Lock()

    def store(self, obs, act, rew, next_obs, done):
        # Find the max priority
        max_priority = np.max(self.tree.tree[-self.tree.capacity:])

        # A priority of 0 would mean this experience is never selected,
        # so fall back to the upper bound as the default priority
        if max_priority == 0:
            max_priority = self.absolute_error_upper

        insertion_pos = super().store(obs, act, rew, next_obs, done)
        self.tree_lock.acquire()
        insertion_pos_tree = self.tree.add(
            max_priority)  # set the max p for new p
        self.tree_lock.release()
        assert insertion_pos == insertion_pos_tree

    def sample_batch(self, batch_size):
        mem_idxs = np.empty((batch_size,), dtype=np.int32)
        tree_idxs = np.empty((batch_size,), dtype=np.int32)
        b_ISWeights = np.empty((batch_size, 1), dtype=np.float32)

        # Divide the range [0, p_total] into n equal segments, as in the PER paper
        priority_segment = self.tree.total_priority / batch_size

        # Anneal PER_b upward each time we sample a new minibatch
        self.PER_b = np.min(
            [1., self.PER_b + self.PER_b_increment_per_sampling])  # max = 1

        # Compute the max weight from the smallest sampling probability
        p_min = self.tree.p_min
        assert p_min > 0
        max_weight = (p_min * batch_size)**(-self.PER_b)
        assert max_weight > 0

        for i in range(batch_size):
            """
            A value is uniformly sample from each range
            """
            a, b = priority_segment * i, priority_segment * (i + 1)
            value = np.random.uniform(a, b)
            """
            Experience that correspond to each value is retrieved
            """
            assert self.tree.data_pointer > 0
            self.tree_lock.acquire()
            index, priority = self.tree.get_leaf(value)
            self.tree_lock.release()
            assert priority > 0, "### index {}".format(index)

            #P(j)
            sampling_probabilities = priority / self.tree.total_priority

            # w_i = (1/N * 1/P(i))**b / max_w == (N * P(i))**(-b) / max_w
            b_ISWeights[i, 0] = batch_size * sampling_probabilities
            assert b_ISWeights[i, 0] > 0
            b_ISWeights[i, 0] = np.power(b_ISWeights[i, 0], -self.PER_b)
            b_ISWeights[i, 0] = b_ISWeights[i, 0] / max_weight

            mem_idxs[i] = index - self.max_size + 1  # tree leaf index -> flat buffer index
            tree_idxs[i] = index
        return self.obs1_buf[mem_idxs],\
            self.acts_buf[mem_idxs],\
            self.rews_buf[mem_idxs],\
            self.obs2_buf[mem_idxs],\
            self.done_buf[mem_idxs],\
            tree_idxs,\
            b_ISWeights

    """
    Update the priorities on the tree
    """

    def batch_update(self, tree_idx, abs_errors):
        abs_errors = np.abs(abs_errors) + self.PER_e  # take the absolute error and avoid 0
        clipped_errors = np.minimum(abs_errors, self.absolute_error_upper)
        ps = np.power(clipped_errors, self.PER_a)

        self.tree_lock.acquire()
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
        self.tree_lock.release()
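
Example #3 expects a different SumTree interface from the one sketched after Example #1: transitions live in SimpleMemory's flat buffers, so the tree holds priorities only, add returns the insertion slot (so the assert insertion_pos == insertion_pos_tree can check the two stay in lockstep), get_leaf returns a (tree index, priority) pair, and a p_min property exposes the smallest sampling probability. A hypothetical adapter on top of the earlier sketch:

class PrioritySumTree(SumTree):
    """Hypothetical adapter matching Example #3's expectations: the tree
    stores priorities only, payloads stay in SimpleMemory's buffers."""

    @property
    def total_priority(self):
        return self.tree[0]

    @property
    def p_min(self):
        # smallest sampling probability among the filled leaves
        leaves = self.tree[self.capacity - 1:self.capacity - 1 + self.counter]
        return np.min(leaves) / self.total_priority

    def add(self, p):
        pos = self.data_pointer  # slot about to be written
        super().add(p, None)  # no payload; data lives in SimpleMemory
        return pos

    def get_leaf(self, v):
        leaf, priority, _ = super().get_leaf(v)
        return leaf, priority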
Example #4
import numpy as np
from collections import namedtuple

# The MEMORY_* constants, EPSILON, array_min2d and SumTree are assumed to be defined elsewhere.
class Memory(object):
    """
    This SumTree code is modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    beta = MEMORY_BETA

    def __init__(self):
        self.limit = MEMORY_CAPACITY
        self.err_tree = SumTree(MEMORY_CAPACITY)
        self.action_shape = (0, MEMORY_ACTION_CNT)
        self.reward_shape = (0, MEMORY_REWARD_CNT)
        self.terminal_shape = self.action_shape
        self.observation_shape = (0, MEMORY_CRITIC_FEATURE_NUM)
        self.store_times = 0
        self.Transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'next_state', 'terminal'))

    def size(self):
        return self.limit if self.store_times > self.limit else self.store_times

    def sample(self, batch_size):
        idxes = np.empty(self.reward_shape, dtype=np.int32)
        isw = np.empty(self.reward_shape, dtype=np.float32)
        obs0 = np.empty(self.observation_shape, dtype=np.float32)
        obs1 = np.empty(self.observation_shape, dtype=np.float32)
        actions = np.empty(self.action_shape, dtype=np.float32)
        rewards = np.empty(self.reward_shape, dtype=np.float32)
        terminals = np.empty(self.terminal_shape, dtype=bool)  # np.bool was removed from modern NumPy
        nan_state = np.array([np.nan] * self.observation_shape[1])

        self.beta = np.min([1., self.beta + MEMORY_BETA_INC_RATE])  # max = 1
        max_p = np.max(self.err_tree.tree[-self.err_tree.capacity:])  # max stored priority (leaves hold p = (|err|+eps)^alpha)
        idx_set = set()
        # sample at most batch_size * 2 times to collect batch_size distinct instances
        for _ in range(batch_size * 2):
            v = np.random.uniform(0, self.err_tree.total_p)
            idx, p, trans = self.err_tree.get_leaf(v)  # the leaf stores the priority, not the raw TD error
            if batch_size == len(idx_set):
                break
            if idx not in idx_set:
                idx_set.add(idx)
            else:
                continue
            if (trans.state == 0).all():
                continue
            idxes = np.row_stack((idxes, np.array([idx])))
            # the leaf value is already a priority, so weight it directly
            # rather than re-applying _getPriority
            isw = np.row_stack(
                (isw, np.array([np.power(p / max_p, -self.beta)])))
            obs0 = np.row_stack((obs0, trans.state))
            obs1 = np.row_stack(
                (obs1,
                 nan_state if trans.terminal.all() else trans.next_state))
            actions = np.row_stack((actions, trans.action))
            rewards = np.row_stack((rewards, trans.reward))
            terminals = np.row_stack((terminals, trans.terminal))

        result = {
            'obs0': array_min2d(obs0),
            'actions': array_min2d(actions),
            'rewards': array_min2d(rewards),
            'obs1': array_min2d(obs1),
            'terminals': array_min2d(terminals),
        }

        return idxes, result, isw

    def _getPriority(self, error):
        return (error + EPSILON)**MEMORY_ALPHA

    def append(self, obs0, action, reward, obs1, terminal, err, training=True):
        if not training:
            return
        trans = self.Transition(obs0, action, reward, obs1, terminal)
        self.err_tree.add(self._getPriority(err), trans)
        self.store_times += 1

    def batch_update(self, tree_idx, errs):
        errs = np.abs(errs) + EPSILON  # convert to abs and avoid 0
        ps = np.power(errs, MEMORY_ALPHA)
        for ti, p in zip(tree_idx, ps):
            self.err_tree.update(ti, p[0])  # errs arrive as column vectors, so take the scalar

    @property
    def nb_entries(self):
        return self.store_times
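
To make the priority and importance-sampling arithmetic used throughout these examples concrete, a small self-contained check (the numbers are arbitrary):

import numpy as np

alpha, beta, eps = 0.6, 0.4, 0.01
td_errors = np.array([2.5, 0.1, 1.0])

priorities = (np.abs(td_errors) + eps) ** alpha  # ≈ [1.74, 0.27, 1.01]
probs = priorities / priorities.sum()            # P(i), sums to 1
weights = (len(td_errors) * probs) ** -beta      # (N * P(i))^(-beta)
weights /= weights.max()                         # normalize so the max weight is 1
print(probs.round(3), weights.round(3))          # rarer samples get larger weights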