def testBuildGraphNoThresholdingNoLSH(self):
    """Verifies that a threshold of 0 retains all positive-weight edges.

    Writes `r3_embeddings` to a fresh embedding file, builds the graph with
    no LSH arguments, and expects each of 'A', 'B' and 'C' to be linked to
    the other two with weight 0.5.
    """
    emb_file = self._create_embedding_file()
    write_embeddings(r3_embeddings, emb_file)
    out_file = self._create_graph_file()
    build_graph_lib.build_graph([emb_file], out_file, similarity_threshold=0)
    produced = graph_utils.read_tsv_graph(out_file)
    # Fully connected three-node graph; every edge carries weight 0.5.
    expected = {
        'A': {'B': 0.5, 'C': 0.5},
        'B': {'A': 0.5, 'C': 0.5},
        'C': {'A': 0.5, 'B': 0.5},
    }
    self.assertDictEqual(produced, expected)
def testBuildGraphWithThresholdWithLSHSufficientLSHRounds(self):
    """Verifies graph building when multiple rounds of LSH bucketing are used.

    With 4 LSH rounds, the resulting graph is expected to link every point
    to exactly the point before it and the point after it in the
    'embeddings' sequence (wrapping around at the ends).
    """
    # Generate the test embeddings and persist them.
    point_count = 20
    embeddings, neighbor_sim = self._build_test_embeddings(point_count)
    emb_file = self._create_embedding_file()
    write_embeddings(embeddings, emb_file)

    # Run the graph builder and load its output as a dictionary.
    out_file = self._create_graph_file()
    build_graph_lib.build_graph(
        [emb_file],
        out_file,
        similarity_threshold=0.9,
        lsh_splits=2,
        lsh_rounds=4,
        random_seed=12345)
    g_actual = graph_utils.read_tsv_graph(out_file)

    def node_name(index):
        """Returns the node ID for `index`, wrapping modulo the point count."""
        return 'id_{}'.format(index % point_count)

    # Construct the expected graph: each point should be a neighbor of the
    # point before it and the point after it in the 'embeddings' sequence.
    # That's because the cosine similarity of adjacent points is ~0.951057,
    # while between every other point it is ~0.809017 (which is below the
    # similarity threshold of 0.9).
    g_expected = {
        node_name(i): {
            node_name(i - 1): neighbor_sim,
            node_name(i + 1): neighbor_sim,
        }
        for i in range(point_count)
    }
    self.assertDictEqual(g_actual, g_expected)
def testBuildGraphWithThresholdWithLSHInsufficientLSHRounds(self):
    """Verifies that a single LSH round misses some of the expected edges.

    Uses the same setup as the multi-round LSH test but with lsh_rounds=1,
    and checks that both the output file's line count and the number of
    edges read back equal 2 * N - 8 rather than the full 2 * N.
    """
    # Generate the test embeddings and persist them.
    point_count = 20
    embeddings, _ = self._build_test_embeddings(point_count)
    emb_file = self._create_embedding_file()
    write_embeddings(embeddings, emb_file)

    # Run the graph builder with only one LSH round.
    out_file = self._create_graph_file()
    build_graph_lib.build_graph(
        [emb_file],
        out_file,
        similarity_threshold=0.9,
        lsh_splits=2,
        lsh_rounds=1,
        random_seed=12345)
    self.assertEqual(self._num_file_lines(out_file), point_count * 2 - 8)

    # Total the per-source neighbor counts; the sum should fall short of
    # 2 * N by exactly 8 edges.
    g_actual = graph_utils.read_tsv_graph(out_file)
    edge_total = sum(len(neighbors) for neighbors in six.itervalues(g_actual))
    self.assertEqual(edge_total, 2 * len(embeddings) - 8,
                     'Expected some edges not to have been found.')
def testGraphBuildingWithThresholding(self):
    """Verifies that a 0.51 similarity threshold yields an empty graph.

    Edges whose similarity is below the threshold must not be written, so
    reading the output back is expected to give an empty dictionary.
    """
    emb_file = self._create_embedding_file()
    # NOTE(review): this uses the `self._write_embeddings` helper, whereas the
    # sibling tests call the module-level `write_embeddings` -- confirm that
    # the method exists and writes the intended embeddings.
    self._write_embeddings(emb_file)
    out_file = self._create_graph_file()
    build_graph_lib.build_graph(
        [emb_file], out_file, similarity_threshold=0.51)
    produced = graph_utils.read_tsv_graph(out_file)
    expected = {}
    self.assertDictEqual(produced, expected)
def testBuildGraphWithThresholdingNoLSH(self):
    """Verifies that edges below the similarity threshold are dropped.

    Builds the graph from `r3_embeddings` with a threshold of 0.51 and no
    LSH arguments, then expects the output file to have zero lines and the
    parsed graph to be empty.
    """
    emb_file = self._create_embedding_file()
    write_embeddings(r3_embeddings, emb_file)
    out_file = self._create_graph_file()
    build_graph_lib.build_graph(
        [emb_file], out_file, similarity_threshold=0.51)
    # Nothing should have been written to the graph file at all.
    self.assertEqual(self._num_file_lines(out_file), 0)
    produced = graph_utils.read_tsv_graph(out_file)
    self.assertDictEqual(produced, {})
def testBuildGraphInvalidLshRoundsValue(self):
    """Expects ValueError for lsh_rounds=0 combined with lsh_splits=1."""
    # Callable form of assertRaises: the remaining arguments are forwarded
    # to build_graph, and the test passes only if it raises ValueError.
    self.assertRaises(
        ValueError,
        build_graph_lib.build_graph,
        [],
        None,
        lsh_splits=1,
        lsh_rounds=0)