예제 #1
0
    def test_get_children(self):
        """Checks that children() splits a node's points into two children.

        With hash prefix '0' and the fixed projection vectors below, the test
        expects the first two datapoints under child '00' and the remaining
        three under child '01'. Each child must also carry the parent's
        clustering parameters and SimHash.
        """
        hash_prefix, dim, max_hash_len = '0', 5, 2
        datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1],
                               [0.8, 0.1, 1, 0.2, 0.4],
                               [0.1, 0.5, 0.3, 0.7, 0.8],
                               [-0.5, 0.1, -0.3, -0.4, 0.2]])
        # Returns children regardless of whether the node should branch. The
        # filtering in the algorithm is done after.
        clustering_param = test_utils.get_test_clustering_param(
            max_depth=max_hash_len)
        projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
        sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
        node = lsh_tree.LshTreeNode(hash_prefix, datapoints, clustering_param,
                                    sh)
        children = node.children()

        # assertCountEqual (unlike assertSameElements) also fails when a
        # prefix is duplicated, so each child must appear exactly once.
        self.assertCountEqual([child.hash_prefix for child in children],
                              ['00', '01'])
        # Row indices of `datapoints` expected in each child.
        expected_rows = {'00': [0, 1], '01': [2, 3, 4]}
        for child in children:
            self.assertEqual(child.clustering_param, clustering_param)
            self.assertEqual(child.sim_hash, sh)
            # assert_array_equal checks shapes as well as values, so a child
            # holding the wrong number of rows cannot pass via broadcasting,
            # and a failure reports the mismatching elements.
            np.testing.assert_array_equal(
                child.nonprivate_points,
                datapoints[expected_rows[child.hash_prefix]])
예제 #2
0
 def test_group_by_next_hash_shape(self):
     """Checks that the "0" and "1" groups together hold every input point.

     Only the row counts are compared; the points themselves are random.
     """
     num_points = 50
     dim, max_hash_len = 10, 6
     sim_hash = lsh.SimHash(dim, max_hash_len)
     points = np.random.normal(size=(num_points, dim))
     groups = sim_hash.group_by_next_hash(points)
     # Add up the number of rows in both groups.
     grouped_total = sum(groups[bit].shape[0] for bit in ("0", "1"))
     self.assertEqual(grouped_total, num_points)
예제 #3
0
 def test_value_errors(self):
     """Checks that group_by_next_hash raises ValueError for these prefixes.

     With max_hash_len set to 6, the first prefix has exactly 6 characters
     and the second has 7.
     """
     num_points = 50
     dim, max_hash_len = 10, 6
     sim_hash = lsh.SimHash(dim, max_hash_len)
     points = np.random.normal(size=(num_points, dim))
     # Checked in order; the first prefix that does not raise fails the test.
     for bad_prefix in ("010010", "0101011"):
         with self.assertRaises(ValueError):
             sim_hash.group_by_next_hash(points, hash_prefix=bad_prefix)
예제 #4
0
 def test_group_by_next_hash(self):
     """Checks the grouping of fixed points under fixed projection vectors.

     With hash prefix "0", the test expects the first two datapoints in group
     "0" and the remaining three in group "1".
     """
     dim, max_hash_len = 5, 2
     hash_prefix = "0"
     projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
     sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
     datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1],
                            [0.8, 0.1, 1, 0.2, 0.4],
                            [0.1, 0.5, 0.3, 0.7, 0.8],
                            [-0.5, 0.1, -0.3, -0.4, 0.2]])
     children = sh.group_by_next_hash(datapoints, hash_prefix)
     # assert_array_equal checks shapes as well as values, so a group with
     # the wrong number of rows cannot pass via broadcasting, and a failure
     # reports the mismatching elements.
     np.testing.assert_array_equal(children["0"], datapoints[[0, 1]])
     np.testing.assert_array_equal(children["1"], datapoints[[2, 3, 4]])
예제 #5
0
    def test_get_children_error(self):
        """Checks that children() raises ValueError for this node.

        The node is built with hash prefix '00', which already has
        max_hash_len (2) characters.
        """
        max_hash_len = 2
        dim = 5
        hash_prefix = '00'
        points = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
        param = test_utils.get_test_clustering_param(max_depth=max_hash_len)
        projections = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
        sim_hash = lsh.SimHash(dim, max_hash_len, projections)
        node = lsh_tree.LshTreeNode(hash_prefix, points, param, sim_hash)

        # Constructing the node must succeed; only children() should raise.
        self.assertRaises(ValueError, node.children)
예제 #6
0
def root_node(data: clustering_params.Data,
              clustering_param: clustering_params.ClusteringParam,
              private_count: typing.Optional[int] = None):
    """Builds the root node of an LSH prefix tree.

    The root uses the empty hash prefix and a new SimHash constructed from
    the data dimension and the tree's maximum depth.

    Args:
      data: Data to use for generating the tree; its `dim` and `datapoints`
        are read.
      clustering_param: Clustering parameters to use for generating the
        tree; `tree_param.max_depth` is passed to SimHash.
      private_count: Private count for the number of datapoints, forwarded
        to the node. If None, the private count will be computed.

    Returns:
      The LshTreeNode created for the empty prefix "".
    """
    dim = data.dim
    max_depth = clustering_param.tree_param.max_depth
    hasher = lsh.SimHash(dim, max_depth)
    root = LshTreeNode("",
                       data.datapoints,
                       clustering_param,
                       hasher,
                       private_count=private_count)
    return root
예제 #7
0
 def test_projection_vectors_shape(self):
     """Checks that projection_vectors has shape (max_hash_len, dim)."""
     dim = 10
     max_hash_len = 6
     expected_shape = (max_hash_len, dim)
     sim_hash = lsh.SimHash(dim, max_hash_len)
     self.assertEqual(sim_hash.projection_vectors.shape, expected_shape)
예제 #8
0
def get_test_sim_hash(dim=10, max_hash_len=1):
    """Returns a SimHash for tests, with defaults for unneeded parameters.

    Args:
      dim: Passed to SimHash as `dim`.
      max_hash_len: Passed to SimHash as `max_hash_len`.
    """
    sim_hash = lsh.SimHash(dim=dim, max_hash_len=max_hash_len)
    return sim_hash