def test_get_children(self):
    """Checks the children produced by a node with hash prefix '0'.

    Asserts that the children carry the prefixes '00' and '01', that each
    child keeps the parent's clustering param and SimHash object, and that
    the five data points are divided between the two children as expected.
    """
    hash_prefix, dim, max_hash_len = '0', 5, 2
    datapoints = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
        [0.8, 0.1, 1, 0.2, 0.4],
        [0.1, 0.5, 0.3, 0.7, 0.8],
        [-0.5, 0.1, -0.3, -0.4, 0.2],
    ])
    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    clustering_param = test_utils.get_test_clustering_param(
        max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, clustering_param, sh)

    children = node.children()

    self.assertSameElements(
        [child.hash_prefix for child in children], ['00', '01'])

    # Rows of `datapoints` expected in each child, keyed by child prefix.
    expected_rows = {'00': [0, 1], '01': [2, 3, 4]}
    for child in children:
        self.assertEqual(child.clustering_param, clustering_param)
        self.assertEqual(child.sim_hash, sh)
        # assert_array_equal checks shape as well as values and reports the
        # mismatching entries, unlike `(a == b).all()`, which broadcasts
        # silently and only ever reports "False is not true".
        np.testing.assert_array_equal(
            child.nonprivate_points,
            datapoints[expected_rows[child.hash_prefix]])
def test_lsh_tree_negative_count_root_errors(self):
    """An LshTree built on a root with a negative private count must fail."""
    origin_points = get_test_origin_points(nonprivate_count=15)
    param = test_utils.get_test_clustering_param()
    hasher = get_test_sim_hash()
    negative_root = lsh_tree.LshTreeNode(
        hash_prefix='0',
        nonprivate_points=origin_points,
        clustering_param=param,
        sim_hash=hasher,
        private_count=-10,
    )

    # The root holds 15 nonprivate points but a private count of -10.
    with self.assertRaises(ValueError):
        lsh_tree.LshTree(negative_root)
def test_filter_branching_nodes_enough_points(self):
    """A node whose private count exceeds the branching minimum is kept."""
    sim_hash = get_test_sim_hash()
    origin_points = get_test_origin_points(nonprivate_count=15)
    param = test_utils.get_test_clustering_param(
        min_num_points_in_branching_node=10)
    # private_count=20 is above the configured minimum of 10.
    populous_node = lsh_tree.LshTreeNode(
        '0', origin_points, param, sim_hash, private_count=20)
    level: lsh_tree.LshTreeLevel = [populous_node]

    surviving = lsh_tree.LshTree.filter_branching_nodes(level)

    self.assertSequenceEqual(surviving, level)
def test_filter_branching_nodes_too_few_points(self):
    """A node whose private count is below the branching minimum is dropped."""
    sim_hash = get_test_sim_hash()
    origin_points = get_test_origin_points(nonprivate_count=15)
    param = test_utils.get_test_clustering_param(
        min_num_points_in_branching_node=10)
    # The nonprivate count (15) is above the minimum of 10 while the private
    # count (1) is below it: the check must use private_count, not the
    # nonprivate count.
    sparse_node = lsh_tree.LshTreeNode(
        '0', origin_points, param, sim_hash, private_count=1)
    level: lsh_tree.LshTreeLevel = [sparse_node]

    surviving = lsh_tree.LshTree.filter_branching_nodes(level)

    self.assertEmpty(surviving)
def test_get_private_count_basic(self, mock_dlaplace_fn):
    """Checks the private count and the argument passed to the noise mock.

    NOTE(review): `mock_dlaplace_fn` is presumably injected by a patch
    decorator that is not part of this block; its configured return value
    determines the expected count of 25 -- confirm against the decorator.
    """
    num_points = 30
    points = get_test_origin_points(nonprivate_count=num_points)
    param = test_utils.get_test_clustering_param(
        epsilon=5, frac_sum=0.2, frac_group_count=0.8, max_depth=9)
    hasher = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=points,
        clustering_param=param,
        sim_hash=hasher,
    )

    private_count = node.get_private_count()

    self.assertEqual(private_count, 25)
    # Presumably 0.4 = epsilon * frac_group_count / (max_depth + 1)
    # = 5 * 0.8 / 10 -- verify against the implementation.
    mock_dlaplace_fn.assert_called_once_with(0.4)
def test_get_children_error(self):
    """children() must raise ValueError for a node with prefix '00'."""
    # The prefix '00' already has max_hash_len (2) characters; presumably
    # that is why no further children can be produced -- confirm against
    # LshTreeNode.children.
    max_hash_len = 2
    dim = 5
    hash_prefix = '00'
    datapoints = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
    ])
    clustering_param = test_utils.get_test_clustering_param(
        max_depth=max_hash_len)
    projection_vectors = np.array([
        [0, 1, 1, -1, 0],
        [1, 0, -1, 0, 0],
    ])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    full_depth_node = lsh_tree.LshTreeNode(
        hash_prefix, datapoints, clustering_param, sh)

    with self.assertRaises(ValueError):
        full_depth_node.children()
def test_get_private_count_cache(self):
    """Two consecutive get_private_count() calls must return the same value."""
    num_points = 30
    points = get_test_origin_points(nonprivate_count=num_points)
    param = test_utils.get_test_clustering_param(epsilon=0.01)
    hasher = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=points,
        clustering_param=param,
        sim_hash=hasher,
    )

    # NOTE(review): nothing is mocked here, so the count is presumably
    # noised; equality of the two calls is what demonstrates caching.
    first_private_count = node.get_private_count()
    second_private_count = node.get_private_count()

    self.assertEqual(first_private_count, second_private_count)