def test_lsh_tree_leaves(self):
    # Verifies that LshTree.leaves equals the expected leaves of this tree.
    #
    # Test tree:
    #   Nodes are nonprivate count + 1.
    #   Branches to the left are 0, to the right are 1.
    #   Nodes in parentheses are filtered out.
    #
    #                 64+1
    #               /      \
    #            8+1        56+1
    #           /   \      /    \
    #       (1+1)   7+1  7+1    49+1
    #                           /   \
    #                       (6+1)   43+1
    sim_hash = get_test_sim_hash()
    clustering_param = test_utils.get_test_clustering_param(
        min_num_points_in_node=8,
        min_num_points_in_branching_node=9,
        max_depth=3)

    # Root of the tree: empty path, 64 nonprivate points.
    root_point_count = 64
    root = TestLshTreeNode(
        '',
        get_test_origin_points(root_point_count),
        clustering_param,
        sim_hash,
        frac_zero=0.125)

    # (path from the root, nonprivate point count) for every surviving leaf,
    # listed in the order the tree is expected to report them.
    leaf_specs = (('01', 7), ('10', 7), ('111', 43))
    want_leaves = [
        TestLshTreeNode(
            path, get_test_origin_points(count), clustering_param, sim_hash)
        for path, count in leaf_specs
    ]

    got_leaves = lsh_tree.LshTree(root).leaves
    self.assertEqual(got_leaves, want_leaves)
def test_lsh_tree_negative_count_root_errors(self):
    # Building an LshTree from a root node whose private count is negative
    # must raise ValueError.
    origin_points = get_test_origin_points(nonprivate_count=15)
    clustering_param = test_utils.get_test_clustering_param()
    sim_hash = get_test_sim_hash()
    negative_count_root = lsh_tree.LshTreeNode(
        '0',
        origin_points,
        clustering_param,
        sim_hash,
        private_count=-10)

    # Only the tree construction is expected to raise; the node itself is
    # built above, outside the assertion.
    self.assertRaises(ValueError, lsh_tree.LshTree, negative_count_root)