def test_segment_tree_basic():
    """Overfilling a capacity-2 tree must drop the low-priority negative entry.

    Three nodes are inserted into a tree that holds two. Across 100 samples
    every drawn value must be positive (the -5 entry is never sampled), and
    both 1 and 2 must show up at least once.
    """
    tree = SegmentTree(2)
    for dat, p in ((-5, 0.1), (2, 1), (1, 1)):
        tree.put(Node(dat, p))
    samples = tree._sample_batch(100)
    assert all(s > 0 for s in samples)
    assert all(k in samples for k in range(1, 3))
def segment_tree_prob_test():
    """Ten equally weighted entries must each report probability 0.1.

    NOTE(review): the name lacks a ``test_`` prefix, so default pytest
    discovery will not collect it unless the project customises
    ``python_functions`` — confirm this is intentional.
    """
    tree = SegmentTree(10)
    for k in range(10):
        tree.put(Node(k, 1))
    for k in range(10):
        assert abs(tree._get_prob(k) - 0.1) < 1e-6
def test_segment_tree_query():
    """Empirical sampling frequency must track each entry's priority share.

    Entries 0..4 are inserted with priority equal to their value. Over 11000
    samples, entry ``i`` is expected at a rate within 0.02 of ``i / 10``.
    """
    tree = SegmentTree(5)
    for i in range(5):
        tree.put(Node(i, i))
    samples = np.array(tree._sample_batch(11000))
    for i in range(5):
        rate = np.mean(samples == i)
        assert abs(rate - i / 10) < 2e-2, f'Actual rate = {rate} != {i/10}'
def test_segment_tree_update():
    """Leaves hold the inserted nodes, and ``set`` overwrites a leaf.

    The leaf for slot ``k`` is read at ``buff.sum[k + buff.size - 1]``; its
    ``.dat`` must equal what was inserted, and after ``set(0, ...)`` slot 0
    must hold the replacement node.
    """
    buff = SegmentTree(5)
    for i in range(5):
        # ``put`` takes a Node (as in every other test in this file); the
        # leaf ``.dat`` checks below depend on a Node being stored.
        buff.put(Node(i, i))
    for i in range(5):
        assert buff.sum[i + buff.size - 1].dat == i
    buff.set(0, Node(6, 6))
    assert buff.sum[buff.size - 1].dat == 6
def test_segment_tree_1():
    """After overfilling a capacity-4 tree, only the 101..104 entries remain.

    Ten nodes are inserted. Every one of 1000 samples must exceed 100, and
    each of 101, 102, 103 and 104 must appear at least once.
    """
    tree = SegmentTree(4)
    entries = [
        (104, 9), (-3, 0.1), (0, 0.001), (1, 1), (103, 10),
        (3, 1), (-4, 0.1), (102, 100), (5, 1), (101, 99),
    ]
    for dat, p in entries:
        tree.put(Node(dat, p))
    samples = tree._sample_batch(1000)
    assert all(s > 100 for s in samples)
    assert all(k in samples for k in range(101, 105))
def segment_tree_prob_test_1():
    """Reported probabilities must equal normalised random weights.

    100 entries get random weights; ``_get_prob(k)`` must match
    ``weight_k / sum(weights)`` to within 1e-6.

    NOTE(review): the name lacks a ``test_`` prefix, so default pytest
    discovery will not collect it — confirm this is intentional.
    """
    n = 100
    tree = SegmentTree(n)
    weights = [random.random() for _ in range(n)]
    for idx, w in enumerate(weights):
        tree.put(Node(idx, w))
    total = sum(weights)
    expected = [w / total for w in weights]
    for idx, prob in enumerate(expected):
        assert abs(tree._get_prob(idx) - prob) < 1e-6
def segment_tree_perf_test():
    """Rough speed check: 20 batches of 100 samples from a 100k-entry tree.

    NOTE(review): the 0.1 s wall-clock threshold is machine-dependent and may
    be flaky on slow or shared hosts. The name lacks a ``test_`` prefix, so
    default pytest discovery will not collect it — confirm that is intended.
    """
    n = int(1e5)
    buff = SegmentTree(n)
    for i in range(n):
        buff.put(Node(i, i))
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump, which corrupts interval measurements.
    start = time.perf_counter()
    for _ in range(20):
        # The array conversion is kept as part of the timed work.
        np.array(buff._sample_batch(100))
    used = time.perf_counter() - start
    assert used < 0.1
def test_segment_tree_dist_1():
    """``put`` returns sequential slots, and ``set_p`` reshapes the distribution.

    Five unit-priority entries must land in slots 0..4. After setting slot
    ``i``'s priority to ``i``, the sampling rate of ``i`` over 30000 draws
    must be within 0.01 of ``i / 10``.
    """
    tree = SegmentTree(5)
    for i in range(5):
        assert tree.put(Node(i, 1)) == i
    for i in range(5):
        tree.set_p(i, i)
    samples = np.array(tree._sample_batch(30000))
    for i in range(5):
        rate = np.mean(samples == i)
        assert abs(rate - i / 10) < 1e-2, f'Actual rate = {rate} != {i/10}'
def test_segment_tree_sample_with_slice():
    """Handles returned by ``sample_batch`` are non-None and accepted by ``set_p``.

    After 5000 inserts into a capacity-100 tree, a 10000-sample batch is
    drawn. Every handle is re-prioritised to 1, and a second batch must
    again yield only non-None handles.
    """
    tree = SegmentTree(100)
    for k in range(5000):
        tree.put(Node(k, k))
    handles, _, _ = tree.sample_batch(10000)
    for idx in handles:
        assert idx is not None
        tree.set_p(idx, 1)
    handles, _, _ = tree.sample_batch(10000)
    assert all(idx is not None for idx in handles)
def test_segment_tree_0():
    """After overfilling a capacity-5 tree, only the 1..5 entries are sampled.

    Eight nodes are inserted. Every one of 10000 samples must be positive,
    and each of 1..5 must appear at least once.
    """
    tree = SegmentTree(5)
    entries = [
        (-3, 0.1), (0, 0.001), (1, 1), (2, 1),
        (3, 1), (-4, 0.1), (4, 1), (5, 1),
    ]
    for dat, p in entries:
        tree.put(Node(dat, p))
    samples = tree._sample_batch(10000)
    assert all(s > 0 for s in samples)
    assert all(k in samples for k in range(1, 6))
# Example no. 11 (scraped page header, originally "Esempio n. 11")
# 0
def set_p_test():
    """After re-prioritising every entry, the queue drains in sorted order.

    1000 ``QNode`` entries are queued. Each wrapped node then gets a fresh
    random ``p``, the list is sorted, and ``set_p`` is applied in that order.
    Popping the queue until empty must reproduce the sorted list.

    NOTE(review): the new ``p`` is written directly onto nodes already in the
    queue before ``set_p`` runs, so this relies on ``set_p`` restoring order
    even when the stored value already equals the new one — confirm against
    ``IndexPriorityQueue.set_p``. The name also lacks a ``test_`` prefix, so
    default pytest discovery will not collect it.
    """
    q = IndexPriorityQueue()
    nodes = [QNode(Node(random.random(), p=k), index=k) for k in range(1000)]
    for node in nodes:
        q.put_nowait(node)

    for node in nodes:
        node.dat.p = random.random()

    nodes.sort()
    for node in nodes:
        q.set_p(node.index, node.dat.p)

    pos = 0
    while not q.empty():
        popped = q.get_nowait()
        expected = nodes[pos]
        assert popped.dat.dat == expected.dat.dat
        assert popped.dat.p == expected.dat.p
        pos += 1