def test_get_children(self):
    """Checks that `children()` partitions a node's points into '00' and '01'.

    A node with hash prefix '0' is built over five fixed datapoints using
    fixed projection vectors. Its children are expected to carry the prefixes
    '00' (rows 0-1 of the data) and '01' (rows 2-4), and to share the parent's
    clustering parameters and SimHash instance.
    """
    hash_prefix, dim, max_hash_len = '0', 5, 2
    datapoints = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
        [0.8, 0.1, 1, 0.2, 0.4],
        [0.1, 0.5, 0.3, 0.7, 0.8],
        [-0.5, 0.1, -0.3, -0.4, 0.2],
    ])
    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    clustering_param = test_utils.get_test_clustering_param(
        max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, clustering_param, sh)

    children = node.children()

    self.assertSameElements([child.hash_prefix for child in children],
                            ['00', '01'])
    # Rows of `datapoints` expected under each child prefix. The assertion
    # above guarantees that every child prefix is a key of this table.
    expected_rows = {'00': [0, 1], '01': [2, 3, 4]}
    for child in children:
        self.assertEqual(child.clustering_param, clustering_param)
        self.assertEqual(child.sim_hash, sh)
        # assert_array_equal verifies shape as well as values, so an empty or
        # differently-shaped result cannot pass through silent broadcasting
        # the way `(a == b).all()` can.
        np.testing.assert_array_equal(
            child.nonprivate_points,
            datapoints[expected_rows[child.hash_prefix]])
def test_group_by_next_hash_shape(self):
    """The "0" and "1" groups together hold as many rows as the input."""
    num_points = 50
    dim, max_hash_len = 10, 6
    hasher = lsh.SimHash(dim, max_hash_len)
    points = np.random.normal(size=(num_points, dim))

    groups = hasher.group_by_next_hash(points)

    zero_rows = groups["0"].shape[0]
    one_rows = groups["1"].shape[0]
    self.assertEqual(zero_rows + one_rows, num_points)
def test_value_errors(self):
    """`group_by_next_hash` raises ValueError for over-long hash prefixes."""
    num_points = 50
    dim, max_hash_len = 10, 6
    hasher = lsh.SimHash(dim, max_hash_len)
    points = np.random.normal(size=(num_points, dim))

    # The first prefix has exactly max_hash_len characters, the second has
    # one more than that.
    for bad_prefix in ("010010", "0101011"):
        with self.assertRaises(ValueError):
            hasher.group_by_next_hash(points, hash_prefix=bad_prefix)
def test_group_by_next_hash(self):
    """Checks the exact split produced by `group_by_next_hash` for prefix "0".

    With fixed projection vectors and five fixed datapoints, rows 0-1 are
    expected in the "0" group and rows 2-4 in the "1" group.
    """
    dim, max_hash_len = 5, 2
    hash_prefix = "0"
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    datapoints = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
        [0.8, 0.1, 1, 0.2, 0.4],
        [0.1, 0.5, 0.3, 0.7, 0.8],
        [-0.5, 0.1, -0.3, -0.4, 0.2],
    ])

    children = sh.group_by_next_hash(datapoints, hash_prefix)

    # assert_array_equal verifies shape as well as values and reports the
    # differing elements, unlike `assertTrue((a == b).all())`, which can pass
    # vacuously on an empty result that broadcasts against the expectation.
    np.testing.assert_array_equal(children["0"], datapoints[[0, 1]])
    np.testing.assert_array_equal(children["1"], datapoints[[2, 3, 4]])
def test_get_children_error(self):
    """`children()` raises ValueError when the prefix is already full length."""
    hash_prefix, dim, max_hash_len = '00', 5, 2
    points = np.array([
        [1.5, 0, 1, 0.5, 0],
        [1, 1, 0, 0.7, 0.1],
    ])
    params = test_utils.get_test_clustering_param(max_depth=max_hash_len)
    projections = np.array([
        [0, 1, 1, -1, 0],
        [1, 0, -1, 0, 0],
    ])
    hasher = lsh.SimHash(dim, max_hash_len, projections)
    # The prefix '00' already has max_hash_len characters.
    leaf = lsh_tree.LshTreeNode(hash_prefix, points, params, hasher)

    with self.assertRaises(ValueError):
        leaf.children()
def root_node(data: clustering_params.Data,
              clustering_param: clustering_params.ClusteringParam,
              private_count: typing.Optional[int] = None):
    """Builds the root node of an LSH prefix tree.

    Args:
      data: Data to use for generating the tree; its `dim` sizes the SimHash
        and its `datapoints` are handed to the root node.
      clustering_param: Clustering parameters to use for generating the tree.
        `tree_param.max_depth` is used as the SimHash maximum hash length.
      private_count: Private count for the number of datapoints. If None, the
        private count will be computed.

    Returns:
      An `LshTreeNode` with the empty hash prefix "" over `data.datapoints`.
    """
    dim = data.dim
    max_depth = clustering_param.tree_param.max_depth
    sim_hash = lsh.SimHash(dim, max_depth)
    root = LshTreeNode(
        "",
        data.datapoints,
        clustering_param,
        sim_hash,
        private_count=private_count,
    )
    return root
def test_projection_vectors_shape(self):
    """Default projection vectors have shape (max_hash_len, dim)."""
    dim, max_hash_len = 10, 6
    hasher = lsh.SimHash(dim, max_hash_len)
    expected_shape = (max_hash_len, dim)
    self.assertEqual(hasher.projection_vectors.shape, expected_shape)
def get_test_sim_hash(dim=10, max_hash_len=1):
    """Builds a SimHash for tests that do not care about its exact parameters.

    Args:
      dim: Forwarded to `lsh.SimHash` as `dim`; defaults to 10.
      max_hash_len: Forwarded to `lsh.SimHash` as `max_hash_len`; defaults
        to 1.

    Returns:
      The constructed `lsh.SimHash`.
    """
    sim_hash = lsh.SimHash(dim=dim, max_hash_len=max_hash_len)
    return sim_hash