def _main(argv):
    """Main function for running the build_graph program.

    Expected command line:
    ``build_graph <input.tfr> [<input.tfr> ...] <output_graph.tsv>``.
    Every positional argument except the last names an input TFRecord file of
    embeddings. The last one names the TSV file the graph is written to. The
    similarity threshold and the ID/embedding feature names are taken from
    the command-line flags.

    Args:
      argv: Command-line arguments; argv[0] is the program name.

    Raises:
      app.UsageError: If fewer than 2 positional arguments are supplied.
    """
    flag = flags.FLAGS
    # NOTE(review): presumably turns off the prefix on INFO-level log lines
    # (absl logging flag) so progress output stays readable.
    flag.showprefixforinfo = False
    if len(argv) < 3:
        raise app.UsageError(
            'Invalid number of arguments; expected 2 or more, got %d' %
            (len(argv) - 1))

    # Delegate to the library entry point instead of re-implementing its
    # read/compute/write pipeline, so the CLI and API cannot drift apart.
    build_graph(argv[1:-1],
                argv[-1],
                similarity_threshold=flag.similarity_threshold,
                id_feature_name=flag.id_feature_name,
                embedding_feature_name=flag.embedding_feature_name)
Ejemplo n.º 2
0
def build_graph(embedding_files,
                output_graph_path,
                similarity_threshold=0.8,
                id_feature_name='id',
                embedding_feature_name='embedding'):
    """Builds a similarity graph from dense embeddings and writes it as TSV.

    Input instances are read from one or more TFRecord files, each holding
    `tf.train.Example` protos. Every example must provide at least these two
    features:

    *   `id`: a singleton `bytes_list` feature that identifies the example.
    *   `embedding`: a `float_list` feature holding the example's dense
        embedding.

    `id` and `embedding` are only the default feature names. If your data
    uses other names, pass them via `id_feature_name` and
    `embedding_feature_name`.

    The cosine similarity is computed for every pair of input examples from
    their embeddings. Each pair whose similarity is greater than or equal to
    `similarity_threshold` becomes an edge. Each edge is written to the file
    named by `output_graph_path` as one TSV line of the form:

    ```
    source_id<TAB>target_id<TAB>edge_weight
    ```

    Output edges are symmetric: whenever edge `A--w-->B` is present, edge
    `B--w-->A` is present as well.

    Args:
      embedding_files: A list of names of TFRecord files containing
        `tf.train.Example` objects, which in turn contain dense embeddings.
      output_graph_path: Name of the file to which the output graph, in TSV
        format, should be written.
      similarity_threshold: Threshold used to decide which edges to keep in
        the resulting graph.
      id_feature_name: Name of the feature in the input `tf.train.Example`
        objects that holds each example's ID.
      embedding_feature_name: Name of the feature in the input
        `tf.train.Example` objects that holds each example's embedding.
    """
    # Load the (ID, embedding) data for every example in the input files.
    id_embeddings = _read_tfrecord_examples(
        embedding_files, id_feature_name, embedding_feature_name)

    # Empty adjacency map. The defaultdict presumably lets _add_edges insert
    # neighbors for a node without creating that node's entry first.
    adjacency = collections.defaultdict(dict)
    _add_edges(id_embeddings, similarity_threshold, adjacency)

    # Persist the populated graph in TSV form.
    graph_utils.write_tsv_graph(output_graph_path, adjacency)
Ejemplo n.º 3
0
 def setUp(self):
     """Creates the temporary input and output files used by each test.

     Writes `_GRAPH` to a TSV graph file, examples a and c as labeled
     training data, and example b as unlabeled neighbor data. It also
     creates a temporary file for the NSL training output.
     """
     super(PackNbrsTest, self).setUp()

     # Graph edges, serialized as a TSV file.
     graph_path = self._create_tmp_file('graph.tsv')
     self._graph_path = graph_path
     graph_utils.write_tsv_graph(graph_path, _GRAPH)

     # Labeled training Examples.
     train_path = self._create_tmp_file('train_data.tfr')
     self._training_examples_path = train_path
     _write_examples(train_path, [_example_a(), _example_c()])

     # Unlabeled neighbor Examples.
     neighbor_path = self._create_tmp_file('neighbor_data.tfr')
     self._neighbor_examples_path = neighbor_path
     _write_examples(neighbor_path, [_example_b()])

     # Destination for the packed NSL training data.
     output_path = self._create_tmp_file('nsl_train_data.tfr')
     self._output_nsl_training_data_path = output_path
Ejemplo n.º 4
0
 def testReadAndWriteTsvGraph(self):
     """Checks that a graph written as TSV reads back as an equal dict."""
     tsv_path = self.create_tempfile('graph.tsv').full_path
     graph_utils.write_tsv_graph(tsv_path, GRAPH)
     round_tripped = graph_utils.read_tsv_graph(tsv_path)
     self.assertDictEqual(round_tripped, GRAPH)