def testBuildGraphNoThresholdingNoLSH(self):
     """Builds a graph with a zero threshold and no LSH arguments.

     The graph written for `r3_embeddings` is expected to be fully connected
     over nodes 'A', 'B' and 'C', with every edge weighted 0.5.
     """
     points = r3_embeddings

     # Stage the input embeddings on disk.
     input_path = self._create_embedding_file()
     write_embeddings(points, input_path)

     # Run the builder and load back whatever it wrote.
     output_path = self._create_graph_file()
     build_graph_lib.build_graph(
         [input_path], output_path, similarity_threshold=0)
     actual_graph = graph_utils.read_tsv_graph(output_path)

     # Every node is linked to both of the other nodes.
     expected_graph = {
         'A': {'B': 0.5, 'C': 0.5},
         'B': {'A': 0.5, 'C': 0.5},
         'C': {'A': 0.5, 'B': 0.5},
     }
     self.assertDictEqual(actual_graph, expected_graph)
    def testBuildGraphWithThresholdWithLSHSufficientLSHRounds(self):
        """Checks that several LSH rounds yield the complete expected graph.

        Builds a graph with `lsh_splits=2`, `lsh_rounds=4` and a similarity
        threshold of 0.9, then compares it against a ring in which each point
        is linked only to its predecessor and successor.
        """
        point_count = 20
        generated = self._build_test_embeddings(point_count)
        embeddings, adjacent_similarity = generated

        # Persist the embeddings so the graph builder can read them.
        input_path = self._create_embedding_file()
        write_embeddings(embeddings, input_path)

        # Build the graph and load the result into a dictionary.
        output_path = self._create_graph_file()
        build_graph_lib.build_graph(
            [input_path],
            output_path,
            similarity_threshold=0.9,
            lsh_splits=2,
            lsh_rounds=4,
            random_seed=12345)
        g_actual = graph_utils.read_tsv_graph(output_path)

        def node_name(index):
            # Wrap around so the first and last points are neighbors too.
            return 'id_{}'.format(index % point_count)

        # Construct the expected graph: each point neighbors the point before
        # it and the point after it in the 'embeddings' sequence. The cosine
        # similarity of adjacent points is ~0.951057, while between every
        # other point it is ~0.809017, which is below the 0.9 threshold.
        g_expected = {
            node_name(i): {
                node_name(i - 1): adjacent_similarity,
                node_name(i + 1): adjacent_similarity,
            }
            for i in range(point_count)
        }
        self.assertDictEqual(g_actual, g_expected)
示例#3
0
    def testBuildGraphWithThresholdWithLSHInsufficientLSHRounds(self):
        """Checks that a single LSH round misses some of the ring's edges.

        Uses the same configuration as the sufficient-rounds test except for
        `lsh_rounds=1`, and expects exactly 8 fewer directed edges than the
        full ring of 2 * N.
        """
        point_count = 20
        embeddings, _ = self._build_test_embeddings(point_count)

        # Persist the embeddings so the graph builder can read them.
        input_path = self._create_embedding_file()
        write_embeddings(embeddings, input_path)

        # Build the graph with only one LSH round.
        output_path = self._create_graph_file()
        build_graph_lib.build_graph(
            [input_path],
            output_path,
            similarity_threshold=0.9,
            lsh_splits=2,
            lsh_rounds=1,
            random_seed=12345)

        # The output file should have 8 lines fewer than the full 2 * N.
        self.assertEqual(
            self._num_file_lines(output_path), point_count * 2 - 8)
        g_actual = graph_utils.read_tsv_graph(output_path)

        # The parsed graph must hold the same reduced number of edges.
        found_edges = sum(
            len(targets) for _, targets in six.iteritems(g_actual))
        self.assertEqual(found_edges, 2 * len(embeddings) - 8,
                         'Expected some edges not to have been found.')
示例#4
0
 def testGraphBuildingWithThresholding(self):
   """Expects an empty graph when the similarity threshold is 0.51."""
   # Write the fixture embeddings through the test helper.
   input_path = self._create_embedding_file()
   self._write_embeddings(input_path)

   # Build with a threshold of 0.51 and read the result straight back.
   output_path = self._create_graph_file()
   build_graph_lib.build_graph(
       [input_path], output_path, similarity_threshold=0.51)
   resulting_graph = graph_utils.read_tsv_graph(output_path)

   # No edge is expected to survive the threshold.
   self.assertDictEqual(resulting_graph, {})
示例#5
0
 def testBuildGraphWithThresholdingNoLSH(self):
     """Expects no edges for `r3_embeddings` at a 0.51 threshold, without LSH.

     Both the raw output file (zero lines) and the parsed graph (empty dict)
     are checked.
     """
     points = r3_embeddings

     # Stage the input embeddings on disk.
     input_path = self._create_embedding_file()
     write_embeddings(points, input_path)

     # Build the graph with a threshold of 0.51.
     output_path = self._create_graph_file()
     build_graph_lib.build_graph(
         [input_path], output_path, similarity_threshold=0.51)

     # The file must be empty, and so must the graph parsed from it.
     written_lines = self._num_file_lines(output_path)
     self.assertEqual(written_lines, 0)
     parsed_graph = graph_utils.read_tsv_graph(output_path)
     self.assertDictEqual(parsed_graph, {})
 def testBuildGraphInvalidLshRoundsValue(self):
     """Expects `ValueError` for `lsh_rounds=0` combined with `lsh_splits=1`."""
     self.assertRaises(
         ValueError,
         build_graph_lib.build_graph,
         [],
         None,
         lsh_splits=1,
         lsh_rounds=0)