def test_sum_priority(self):
    """The tree's reported total priority matches the sum over its leaves."""
    tree = SumTree(10, [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5, 8, 7, -1)])
    expected = 0
    for leaf in tree.get_leaves():
        expected = expected + leaf.get_priority()
    self.assertAlmostEqual(expected, tree.get_sum_priority(), places=8)
def test_TD_order(self):
    """Leaves come out in non-increasing order of absolute TD error."""
    tree = SumTree(10, [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5, 8, 7, -1)])
    previous = None
    for leaf in tree.get_leaves():
        current = abs(leaf.TD_error)
        if previous is not None:
            # Each leaf's |TD error| must not exceed its predecessor's.
            self.assertGreaterEqual(previous, current)
        previous = current
def test_remove_add_items(self):
    """Adding items to a tree created at capacity puts them at the front.

    The tree is built with exactly 7 items for a capacity of 7. X, Y and Z
    are then added in that order, and the first three leaves must be
    Z, Y, X, i.e. the reverse of insertion order.
    """
    items = [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5)]
    tree = SumTree(7, items)
    # The original stored get_sum_priority() in an unused local; it was dead
    # code that implied a check on the total which never happened.
    x, y, z = Item("X", 99), Item("Y", 98), Item("Z", 97)
    for item in (x, y, z):
        tree.add_item(item)
    self.assertListEqual(list(tree.get_leaves())[:3], [z, y, x])
def test_rank_order(self):
    """Leaf ranks are strictly increasing in iteration order."""
    tree = SumTree(10, [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5, 8, 7, -1)])
    last_rank = None
    for leaf in tree.get_leaves():
        rank = leaf.get_rank()
        if last_rank is not None:
            self.assertLess(last_rank, rank)
        last_rank = rank
def test_priority_order(self):
    """Leaf priorities are non-increasing in iteration order."""
    tree = SumTree(10, [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5, 8, 7, -1)])
    last_priority = None
    for leaf in tree.get_leaves():
        priority = leaf.get_priority()
        if last_priority is not None:
            self.assertGreaterEqual(last_priority, priority)
        last_priority = priority
def test_adding_priority(self):
    """Filling an empty tree one item at a time keeps priorities non-increasing."""
    tree = SumTree(7, [])
    for td in (6, 3, 1, 0, 9, 8, 5):
        tree.add_item(Item("A", td))
    last_priority = None
    for leaf in tree.get_leaves():
        priority = leaf.get_priority()
        if last_priority is not None:
            self.assertGreaterEqual(last_priority, priority)
        last_priority = priority
def test_changing_priority(self):
    """After adding three high-TD items to a full tree, priorities stay ordered.

    NOTE(review): despite the name, no existing item's priority is changed
    here; the test only adds X, Y and Z and re-checks the ordering. Confirm
    whether an explicit priority update was intended.
    """
    items = [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5)]
    tree = SumTree(7, items)
    # The original captured get_sum_priority() into an unused local; removed.
    for item in (Item("X", 99), Item("Y", 98), Item("Z", 97)):
        tree.add_item(item)
    last_priority = None
    for leaf in tree.get_leaves():
        priority = leaf.get_priority()
        if last_priority is not None:
            self.assertGreaterEqual(last_priority, priority)
        last_priority = priority
def test_tree_sort(self):
    """PEReplayMemory.sort_tree() leaves the tree ordered by |TD error|.

    A large SumTree is injected as ``PEReplayMemory.tree``. It includes a
    ``None`` TD error and negative values. One experience and two extra
    items are added, the tree is sorted, and the leaves must then be in
    non-increasing order of ``abs(TD_error)``.
    """
    # Function-scope import keeps this edit independent of the file header.
    from unittest import mock

    td_errors = [
        6, 3, 999, 0, 9, 8, 5, None, 3, 1,
        0, 9, 8, 5, 6, 3, 1, 0, 9, 68,
        5, 16, 3, 1, 0, 9, 8, 5, 6, 3,
        1, 0, 89, 8, 5, 6, 3, 1, 0, 9,
        8, 5, 6, 3, 1, 0, 9, 8, 5, 6,
        33, 1, 70, 9, 8, 5, 6, 3, 1, 0,
        9, 8, 5, 6, 3, 1, 0, -89, 8, 5,
        6, 3, 1, 0, 9, 8, 5, 68, 3, 1,
        0, 9, 8, 5, 6, 3, 1, 0, 9, 8,
        5, 6, 3, 1, 0, 9, 8, 5, 6, 3,
        1, 0, 9, 8, 15, 6, 3, 1, 0, 9,
        8, 5, 6, 3, 1, 0, 9, 8, 15,
    ]
    tree = SumTree(1000000, [Item("A", td) for td in td_errors])
    # Assigning PEReplayMemory.tree directly leaked the injected tree into
    # every later test. patch.object restores the class attribute on exit,
    # or removes it if it did not exist before.
    with mock.patch.object(PEReplayMemory, "tree", tree, create=True):
        memory = PEReplayMemory()
        state = np.zeros((84, 84, 4))
        memory.add_experience(0, state, 0, 0, False)
        # NOTE(review): abs(leaf.TD_error) below only works if Item maps a
        # None / omitted TD error to a number. Confirm in Item.
        memory.tree.add_item(Item("None", None))
        memory.tree.add_item(Item("Skipped"))
        memory.sort_tree()
        # Inspect the tree the memory actually sorted. The original iterated
        # the local tree, which is the same object only if __init__ does not
        # replace self.tree.
        previous = None
        for leaf in memory.tree.get_leaves():
            current = abs(leaf.TD_error)
            if previous is not None:
                self.assertGreaterEqual(previous, current)
            previous = current
def test_get_leaves(self):
    """get_leaves() equals the constructor's item list wrapped in a deque.

    NOTE(review): test_TD_order builds a tree from the same TD values and
    expects leaves sorted by |TD error|. This test instead compares against
    ``items`` in list order. Both can pass only if SumTree reorders the
    passed list in place, or if Item equality is looser than identity.
    Confirm which is intended.
    """
    items = [Item("A", td) for td in (6, 3, 1, 0, 9, 8, 5, 8, 7, -1)]
    tree = SumTree(10, items)
    expected = deque(items, maxlen=10)
    self.assertEqual(tree.get_leaves(), expected)