def test_segment_tree_update():
    """Check node ``min`` values of a 4-element ``SegmentTree`` after range adds.

    Per the assertion messages, the first add is expected to stay lazy
    (node 2 still reports ``min == 0``). After the second, overlapping add,
    node 2 reports the updated minimum while node 5 still reports ``min == 0``.
    """
    tree = SegmentTree(4)

    def node_min(index):
        # Read the ``min`` field of the node stored at ``tree.tree[index]``.
        return tree.tree[index].min

    # First range add: node 2 is expected not to have absorbed it yet.
    tree.add(0, 1, 2)
    assert node_min(2) == 0, "laziness"

    # Second range add overlaps the first one.
    tree.add(1, 2, 3)
    assert node_min(2) == 2, "lazy update intermedite"
    assert node_min(5) == 0, "prop works"


def test_segment_tree_queries():
    """Randomized differential test of ``SegmentTree`` against ``Naive``.

    Applies a random mix of range-add and range-min operations to both
    structures and checks that every min query agrees.

    Randomness comes from a private, fixed-seed ``random.Random`` instance.
    A failing operation sequence can therefore be replayed exactly, and the
    global RNG state other tests may depend on is left untouched.
    """
    INTERFACE_SIZE = 10000
    NUM_QUERIES = 100
    SEED = 0

    rng = random.Random(SEED)
    st = SegmentTree(INTERFACE_SIZE)
    nv = Naive(INTERFACE_SIZE)  # presumably a brute-force reference implementation

    for _ in range(NUM_QUERIES):
        left = rng.randint(0, INTERFACE_SIZE - 1)
        right = rng.randint(left, INTERFACE_SIZE - 1)

        if rng.random() > 0.5:
            val = rng.randint(1, 100)
            # Logged so the captured output shows the operation history on failure.
            print(f'add {val}: {left} - {right}')
            st.add(left, right, val)
            nv.add(left, right, val)
        else:
            seg_ans = st.min_query(left, right)
            nai_ans = nv.min_query(left, right)
            assert seg_ans == nai_ans, f"{left} {right} correctness"
# Example #3 (score: 0)
class ReplayMemory:
    """Replay buffer that stores transitions with a ``td`` value in a ``SegmentTree``.

    Sampling is stratified: ``[0, self.memory.total]`` is split into
    ``batch_size`` equal slices and one value is drawn from each slice.

    NOTE(review): this class uses the storage interface ``add(td, data)``,
    ``total``, ``get(value) -> (tree_index, td, data)`` and
    ``update(tree_index, td)``. That presumably describes a sum tree of
    capacity ``max_memory`` (prioritized replay); confirm against the
    ``SegmentTree`` implementation in use.
    """

    def __init__(self, max_memory=1000):
        self.max_memory = max_memory
        self.memory = SegmentTree(max_memory)
        # Number of transitions added, capped at ``max_memory``.
        self._count = 0

    @property
    def count(self):
        """Number of transitions added so far, never more than ``max_memory``."""
        return self._count

    def add_memory(self, state_input, best_action, reward, done,
                   next_state_input, td):
        """Store one transition together with its ``td`` value.

        The transition is stored as the list
        ``[state_input, best_action, reward, done, next_state_input]``.
        ``td`` is passed to ``self.memory.add`` as the entry's weight
        (presumably its TD-error priority; confirm with callers).
        """
        data = [state_input, best_action, reward, done, next_state_input]

        self.memory.add(td, data)

        # Strict ``<``: the count must stop at ``max_memory``. The previous
        # ``<=`` allowed it to reach ``max_memory + 1``.
        if self._count < self.max_memory:
            self._count += 1

    def get_memory(self, batch_size):
        """Sample ``batch_size`` stored transitions by stratified sampling.

        Args:
            batch_size: number of samples to draw. Values ``<= 0`` yield empty
                results instead of dividing by zero.

        Returns:
            A tuple ``(tree_indexes, tds, batch)`` of parallel lists: the
            index, ``td`` value and stored data of each sampled entry, as
            returned by ``self.memory.get``.
        """
        if batch_size <= 0:
            return [], [], []

        segment = self.memory.total / batch_size

        batch_tree_index = []
        tds = []
        batch = []

        for i in range(batch_size):
            low = segment * i
            high = segment * (i + 1)
            # Keep the sample in its own variable. Reassigning ``segment``
            # here (as before) would corrupt the bounds of every later slice.
            value = random.uniform(low, high)
            tree_index, td, data = self.memory.get(value)
            batch_tree_index.append(tree_index)
            tds.append(td)
            batch.append(data)

        return batch_tree_index, tds, batch

    def update_memory(self, tree_indexes, tds):
        """Write new ``td`` values back for previously sampled entries.

        ``tds[i]`` is applied to ``tree_indexes[i]``. ``tds`` must be at least
        as long as ``tree_indexes`` (an ``IndexError`` is raised otherwise).
        """
        for i, tree_index in enumerate(tree_indexes):
            self.memory.update(tree_index, tds[i])


def test_segment_tree_prop():
    """Check pending-update (``delta``) bookkeeping of a 4-element ``SegmentTree``.

    Per the assertion messages, full-cover adds are expected to be recorded
    lazily in a node's ``delta``. A later partially-overlapping add is
    expected to propagate: node 2's ``delta`` returns to 0 and node 5 holds
    the combined value. The node indices and counts are specific to
    ``SegmentTree(4)``.
    """
    st = SegmentTree(4)

    def count_zero_delta():
        # Slots in ``st.tree`` may be falsy (presumably unused ``None``
        # entries); count only real nodes that carry no pending delta.
        return sum(1 for node in st.tree if node and node.delta == 0)

    st.add(0, 1, 2)
    assert st.tree[2].delta == 2, "lazy update"
    assert count_zero_delta() == 6

    st.add(3, 3, 1)
    assert st.tree[7].delta == 1, "lazy update leaf"
    assert count_zero_delta() == 5

    st.add(1, 2, 3)
    assert st.tree[2].delta == 0, "lazy update intermedite"
    assert st.tree[5].delta == 5, "prop works"