Example #1
0
    def test_get_children(self):
        """Checks that children() partitions a node's points by hash prefix.

        The node's prefix '0' is one character shorter than max_hash_len, and
        the test expects children with prefixes '00' and '01' that keep the
        parent's clustering parameters and hash object and that split the
        parent's points between them.
        """
        hash_prefix, dim, max_hash_len = '0', 5, 2
        datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1],
                               [0.8, 0.1, 1, 0.2, 0.4],
                               [0.1, 0.5, 0.3, 0.7, 0.8],
                               [-0.5, 0.1, -0.3, -0.4, 0.2]])
        # Returns children regardless of whether the node should branch. The
        # filtering in the algorithm is done after.
        clustering_param = test_utils.get_test_clustering_param(
            max_depth=max_hash_len)
        projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
        sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
        node = lsh_tree.LshTreeNode(hash_prefix, datapoints, clustering_param,
                                    sh)
        children = node.children()

        # Row indices of datapoints expected under each child prefix.
        expected_rows = {'00': [0, 1], '01': [2, 3, 4]}
        self.assertSameElements([child.hash_prefix for child in children],
                                list(expected_rows))
        for child in children:
            self.assertEqual(child.clustering_param, clustering_param)
            self.assertEqual(child.sim_hash, sh)
            # assert_array_equal also fails on a shape mismatch and reports
            # the differing entries, unlike assertTrue((a == b).all()).
            np.testing.assert_array_equal(
                child.nonprivate_points,
                datapoints[expected_rows[child.hash_prefix]])
Example #2
0
 def test_lsh_tree_negative_count_root_errors(self):
     """Checks that LshTree raises ValueError for a negative private count."""
     origin_points = get_test_origin_points(nonprivate_count=15)
     param = test_utils.get_test_clustering_param()
     sim_hash = get_test_sim_hash()
     # The node is built outside the assertion: only constructing the tree
     # from it is expected to raise.
     negative_root = lsh_tree.LshTreeNode('0',
                                          origin_points,
                                          param,
                                          sim_hash,
                                          private_count=-10)
     self.assertRaises(ValueError, lsh_tree.LshTree, negative_root)
Example #3
0
 def test_filter_branching_nodes_enough_points(self):
     """Checks that a node with a large enough private count is kept."""
     sim_hash = get_test_sim_hash()
     points = get_test_origin_points(nonprivate_count=15)
     param = test_utils.get_test_clustering_param(
         min_num_points_in_branching_node=10)
     # private_count (20) exceeds min_num_points_in_branching_node (10), so
     # the node is expected to survive the filter.
     node = lsh_tree.LshTreeNode('0',
                                 points,
                                 param,
                                 sim_hash,
                                 private_count=20)
     level: lsh_tree.LshTreeLevel = [node]
     kept = lsh_tree.LshTree.filter_branching_nodes(level)
     self.assertSequenceEqual(kept, level)
Example #4
0
 def test_filter_branching_nodes_too_few_points(self):
     """Checks that a node with too small a private count is filtered out."""
     sim_hash = get_test_sim_hash()
     points = get_test_origin_points(nonprivate_count=15)
     param = test_utils.get_test_clustering_param(
         min_num_points_in_branching_node=10)
     # The nonprivate count (15) is above the minimum (10) while the
     # private count (1) is below it; the node is expected to be dropped,
     # so the private count is what the check should use.
     node = lsh_tree.LshTreeNode('0',
                                 points,
                                 param,
                                 sim_hash,
                                 private_count=1)
     level: lsh_tree.LshTreeLevel = [node]
     remaining = lsh_tree.LshTree.filter_branching_nodes(level)
     self.assertEmpty(remaining)
Example #5
0
 def test_get_private_count_basic(self, mock_dlaplace_fn):
     """Checks get_private_count()'s result and the mock's call argument.

     mock_dlaplace_fn is supplied from outside this method; the patch that
     injects it is not part of this block.
     """
     points = get_test_origin_points(nonprivate_count=30)
     param = test_utils.get_test_clustering_param(epsilon=5,
                                                  frac_sum=0.2,
                                                  frac_group_count=0.8,
                                                  max_depth=9)
     node = lsh_tree.LshTreeNode(hash_prefix='',
                                 nonprivate_points=points,
                                 clustering_param=param,
                                 sim_hash=get_test_sim_hash())
     private_count = node.get_private_count()
     self.assertEqual(private_count, 25)
     # NOTE(review): 0.4 matches epsilon * frac_group_count / (max_depth + 1)
     # = 5 * 0.8 / 10 -- confirm against get_private_count's implementation.
     mock_dlaplace_fn.assert_called_once_with(0.4)
Example #6
0
    def test_get_children_error(self):
        """Checks that children() raises ValueError for a full-length prefix."""
        max_hash_len = 2
        # The prefix already has max_hash_len characters.
        full_prefix = '00'
        points = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
        param = test_utils.get_test_clustering_param(max_depth=max_hash_len)
        projections = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
        sim_hash = lsh.SimHash(5, max_hash_len, projections)
        node = lsh_tree.LshTreeNode(full_prefix, points, param, sim_hash)

        self.assertRaises(ValueError, node.children)
Example #7
0
    def test_get_private_count_cache(self):
        """Checks that get_private_count() is stable across repeated calls."""
        points = get_test_origin_points(nonprivate_count=30)
        param = test_utils.get_test_clustering_param(epsilon=0.01)
        node = lsh_tree.LshTreeNode(hash_prefix='',
                                    nonprivate_points=points,
                                    clustering_param=param,
                                    sim_hash=get_test_sim_hash())

        # NOTE(review): nothing is mocked in this method, so the count
        # presumably includes random noise; two equal results then indicate
        # the first value is reused rather than recomputed -- confirm.
        first_count = node.get_private_count()
        second_count = node.get_private_count()
        self.assertEqual(first_count, second_count)