Пример #1
0
 def test_stress(self):
     """Randomized check of SegmentTree range assignment against a NumPy array.

     Each epoch assigns one random value to a random half-open range
     ``[l, r)`` in both the tree and a reference array, then checks every
     point query on the tree against the array.
     """
     n = 1000
     epochs = 100
     # Fixed seed so a failing run can be reproduced instead of being flaky.
     rng = np.random.default_rng(0)
     arr = np.zeros(n)
     tree = SegmentTree(0, n)
     tree.update(0, n, 0)
     for _ in range(epochs):
         # Two random endpoints in [0, n), sorted; r is bumped so the
         # half-open range [l, r) always contains at least one element.
         l, r = sorted(rng.integers(0, n, 2))
         r += 1
         val = rng.random()
         tree.update(l, r, val)
         arr[l:r] = val
         # Distinct loop name: the original reused the epoch variable here.
         for idx in range(n):
             self.assertEqual(arr[idx], tree.query(idx),
                              msg="mismatch at index %d" % idx)
Пример #2
0
class ReplayMemory:
    """Replay buffer that stores transitions in a SegmentTree keyed by TD value.

    Each transition is stored as the list
    ``[state_input, best_action, reward, done, next_state_input]`` with its
    ``td`` value passed to the tree as the priority.
    """

    def __init__(self, max_memory=1000):
        self.max_memory = max_memory
        self.memory = SegmentTree(max_memory)
        # Number of transitions added so far, saturating at max_memory.
        self._count = 0

    @property
    def count(self):
        """Number of stored transitions; never exceeds ``max_memory``."""
        return self._count

    def add_memory(self, state_input, best_action, reward, done,
                   next_state_input, td):
        """Add one transition to the tree with priority ``td``."""
        data = [state_input, best_action, reward, done, next_state_input]

        self.memory.add(td, data)

        # Strict `<` so the count stops exactly at max_memory.
        if self._count < self.max_memory:
            self._count += 1

    def get_memory(self, batch_size):
        """Sample ``batch_size`` transitions by stratified sampling.

        ``self.memory.total`` (presumably the sum of stored priorities) is
        split into ``batch_size`` equal-width strata, and one point is drawn
        uniformly from each stratum and looked up in the tree.

        Returns:
            ``(batch_tree_index, tds, batch)``: three lists of length
            ``batch_size`` holding the tree index, TD value and stored data
            of each sampled transition.

        Raises:
            ZeroDivisionError: if ``batch_size`` is 0.
        """
        segment = self.memory.total / batch_size

        batch_tree_index = []
        tds = []
        batch = []

        for i in range(batch_size):
            # Draw into a separate name so `segment` keeps the stratum width;
            # overwriting it would shift every later stratum.
            value = random.uniform(segment * i, segment * (i + 1))
            tree_index, td, data = self.memory.get(value)
            batch_tree_index.append(tree_index)
            tds.append(td)
            batch.append(data)

        return batch_tree_index, tds, batch

    def update_memory(self, tree_indexes, tds):
        """Set the priority at each tree index to the matching TD value."""
        for tree_index, td in zip(tree_indexes, tds):
            self.memory.update(tree_index, td)
Пример #3
0
from segment_tree import SegmentTree 

arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# `t` is the segment tree built over `arr`.
t = SegmentTree(arr)

# Range queries, run in order: (left, right, operation, printed label).
for lo, hi, op, label in (
    (2, 9, "max", "The maximum value of this range is : "),
    (2, 9, "min", "The minimum value of this range is : "),
    (2, 7, "sum", "The sum of this range is : "),
):
    a = t.query(lo, hi, op)
    print(label, a)

# Point update, presumably setting index 2 to 25.
t.update(2, 25)

# NOTE(review): this prints the caller's list `arr`; it reflects the update
# only if SegmentTree mutates the list it was built from — confirm.
print("The updated array is : ", arr)
Пример #4
0
class PrioritizedReplayMemory:
    """Prioritized experience replay backed by a SegmentTree.

    ``sample`` draws one transition from each of ``batch_size`` equal-width
    strata of the total priority and returns importance-sampling weights
    whose exponent (``priority_weight``) is annealed toward 1 on every call.
    """

    def __init__(self, args, capacity):
        self.capacity = capacity
        self.discount = args.gamma
        # Importance-sampling exponent; increased toward 1 in sample().
        self.priority_weight = args.priority_weight
        # Exponent applied to clipped errors in update_priorities().
        self.priority_exponent = args.priority_exponent
        # Upper clip for error magnitudes in update_priorities().
        self.absolute_error_upper = args.absolute_error_upper
        self.t = 0  # Internal episode timestep counter
        # Store experiences in a wrap-around cyclic buffer within a sum tree
        # for querying priorities.
        self.tree = SegmentTree(capacity)
        # Per-call step that takes priority_weight from its initial value to 1
        # over `capacity` calls to sample().
        self.priority_weight_increase = (1 -
                                         args.priority_weight) / self.capacity

    # Adds state and action at time t, reward and done at time t + 1
    def append(self, state, action, reward, next_state, done):
        """Store one transition with the tree's current maximum priority."""
        self.tree.append(
            Experience(state, action, reward, next_state, done),
            self.tree.max)  # Store new transition with maximum priority
        self.t = 0 if done else self.t + 1  # Start new episodes with t = 0

    def _get_sample_from_segment(self, segment, i):
        """Draw one transition from the i-th priority stratum.

        A point is sampled uniformly in ``[i * segment, (i + 1) * segment]``
        and looked up with ``tree.find``; the draw is repeated until it hits
        an entry whose un-normalised probability is non-zero.

        Returns:
            ``(prob, idx, tree_idx, experience)``.
        """
        while True:
            sample = np.random.uniform(i * segment, (i + 1) * segment)
            # Retrieve sample from tree with un-normalised probability.
            prob, idx, tree_idx = self.tree.find(sample)
            # Only zero-probability hits are rejected and resampled.
            if prob != 0:
                break

        experience = self.tree.get(idx)

        return prob, idx, tree_idx, experience

    @staticmethod
    def _stack_field(experiences, field, dtype):
        """Stack one Experience attribute across the batch into a tensor.

        ``None`` experiences are skipped; the result is moved to the
        module-level ``device`` and cast to ``dtype``.
        """
        values = [getattr(exp, field) for exp in experiences if exp is not None]
        return torch.from_numpy(np.vstack(values)).to(device=device).to(dtype=dtype)

    def sample(self, batch_size):
        """Sample a prioritized batch.

        Returns:
            ``(tree_idxs, states, actions, rewards, next_states, dones,
            weights)``, where ``weights`` are importance-sampling weights
            normalised by their batch maximum.

        Raises:
            ValueError: if the tree's total priority is not positive (e.g.
                nothing stored yet); sampling would otherwise retry forever.
        """
        # Sum of all priorities (used to create a normalised distribution).
        p_total = self.tree.total()
        if p_total <= 0:
            raise ValueError(
                "cannot sample: total priority is zero (memory empty?)")

        self.priority_weight = min(
            self.priority_weight + self.priority_weight_increase, 1)
        # Batch-size number of equal-width segments over the total priority.
        segment = p_total / batch_size

        batch = [
            self._get_sample_from_segment(segment, i)
            for i in range(batch_size)
        ]
        probs, _idxs, tree_idxs, experiences = zip(*batch)

        # NOTE(review): None experiences are dropped from these tensors but
        # not from `weights`, so row counts could diverge if any appear.
        states = self._stack_field(experiences, "state", torch.float32)
        actions = self._stack_field(experiences, "action", torch.long)
        rewards = self._stack_field(experiences, "reward", torch.float32)
        next_states = self._stack_field(experiences, "next_state",
                                        torch.float32)
        dones = self._stack_field(experiences, "done", torch.float32)

        # Calculate normalised probabilities.
        probs = np.array(probs, dtype=np.float32) / p_total
        # Presumably the number of filled slots: capacity once the buffer
        # has wrapped, otherwise the tree's write index.
        capacity = self.capacity if self.tree.full else self.tree.index
        # Compute importance-sampling weights w.
        weights = (capacity * probs)**-self.priority_weight
        # Normalise by max importance-sampling weight from batch.
        weights = torch.tensor(
            weights / weights.max(), dtype=torch.float32, device=device)
        return tree_idxs, states, actions, rewards, next_states, dones, weights

    def update_priorities(self, idxs, priorities):
        """Set new priorities for the given tree indices.

        ``priorities`` are raw errors; their magnitude is clipped at
        ``absolute_error_upper`` and raised to ``priority_exponent``.
        """
        # Take the magnitude first: a negative error raised to a fractional
        # exponent is NaN, which would corrupt the tree's sums.
        clipped_errors = np.minimum(np.abs(priorities),
                                    self.absolute_error_upper)
        clipped_errors = np.power(clipped_errors, self.priority_exponent)
        for idx, priority in zip(idxs, clipped_errors):
            self.tree.update(idx, priority)

    def __len__(self):
        return len(self.tree)
Пример #5
0
#autor: Manjarrez Hernandez Raul
#carrera: Ingenieria en Sistemas Computacionales
from segment_tree import SegmentTree

# An array with some elements. It is loaded into a segment tree, and `t` is
# the segment-tree object used for every operation below.
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# t is used to run operations over the array's segments.
t = SegmentTree(arr)

# Find the maximum value between indices 0 and 10.
a = t.query(0, 10, "max")
print("El Valor Maximo del Rango es: ", a)

# Find the minimum value between indices 4 and 10.
a = t.query(4, 10, "min")
print("El Valor Minimo del Rango es: ", a)

# Find the sum of the values between indices 0 and 3.
a = t.query(0, 3, "sum")
print("La suma del Rango es: ", a)

# Update the value at a particular index: index 2 is set to 57.
t.update(2, 57)

# NOTE(review): this prints the caller's list `arr`; it shows 57 at index 2
# only if SegmentTree mutates the list it was built from — confirm.
print("El Arreglo Actualizado es : ", arr)
Пример #6
0
from segment_tree import SegmentTree

a = [18, 17, 13, 19, 15, 11, 20]
st = SegmentTree(a)

# Each check is (left, right, expected index of the minimum of a[left..right]),
# with both ends inclusive.
for lo, hi, expected in (
    (1, 3, 2),
    (3, 4, 4),
    (0, 0, 0),
    (0, 1, 1),
    (0, 6, 5),
    (4, 6, 5),
):
    assert st.rmq(lo, hi) == expected

# After this update the old minimum at index 5 is replaced by a large value,
# so ranges that covered it now report a different index.
st.update(5, 100)

for lo, hi, expected in (
    (1, 3, 2),
    (3, 4, 4),
    (0, 0, 0),
    (0, 1, 1),
    (0, 6, 2),
    (4, 6, 4),
    (4, 5, 4),
):
    assert st.rmq(lo, hi) == expected