def test_run(self):
    # End-to-end check of embedding_generator.run: with projection disabled,
    # running the generator must leave behind the output directory and a
    # single-shard embeddings TFRecord file inside it.
    FLAGS.projected_dim = None

    # Precondition: the output path must not already be a regular file.
    # (os.path.isfile is False for directories, so this does not detect a
    # pre-existing directory.)
    self.assertFalse(os.path.isfile(FLAGS.embed_output_dir))

    # Exercise the code under test.
    embedding_generator.run(FLAGS)

    # The output location must exist once the run has finished.
    output_dir = FLAGS.embed_output_dir
    self.assertTrue(os.path.exists(output_dir))

    # The embeddings shard must have been written as a regular file.
    embedding_path = os.path.join(output_dir, "emb-00000-of-00001.tfrecords")
    self.assertTrue(os.path.isfile(embedding_path))
def main(args):
    """Entry point main function.

    Resolves the requested operation with ``validate_args(args)`` and
    dispatches on it:

    * ``"generate"``: runs ``generator.run(FLAGS)``.
    * ``"build"``: runs ``builder.run(FLAGS)``.
    * ``"e2e"``: runs the generator, then the builder. If
      ``FLAGS.projected_dim`` is set (truthy), ``FLAGS.dimensions`` is
      overwritten with it before the builder runs.
    * anything else: loads a similarity finder with ``finder.load(FLAGS)``
      and enters an interactive loop that reads a query from stdin and
      prints the ``FLAGS.num_matches`` similar items returned for it.

    The interactive loop ends cleanly (no traceback) when stdin reaches
    end-of-file or the user interrupts the prompt with Ctrl-C.

    Args:
        args: Command-line arguments, forwarded unchanged to
            ``validate_args``.
    """
    operation = validate_args(args)
    print("Selected operation: {}".format(operation))

    if operation == "generate":
        print("Generating embeddings...")
        generator.run(FLAGS)
        print("Embedding generation completed.")

    elif operation == "build":
        print("Building ANN index...")
        builder.run(FLAGS)
        print("Building ANN index completed.")

    elif operation == "e2e":
        print("Generating embeddings and building ANN index...")
        generator.run(FLAGS)
        print("Embedding generation completed.")
        # When a projection dimension is configured, the index is built
        # with that dimensionality instead of the configured one.
        if FLAGS.projected_dim:
            FLAGS.dimensions = FLAGS.projected_dim
        builder.run(FLAGS)
        print("Building ANN index completed.")

    else:
        print("Querying the ANN index...")
        similarity_finder = finder.load(FLAGS)
        num_matches = FLAGS.num_matches
        while True:
            print("Enter your query: ", end="")
            try:
                # str() is kept from the original; presumably a guard for
                # interpreters where input() may not return str.
                query = str(input())
            except (EOFError, KeyboardInterrupt):
                # input() raises EOFError on end-of-input (Ctrl-D or an
                # exhausted pipe); treat that, and Ctrl-C at the prompt,
                # as a request to leave the loop instead of crashing.
                print()
                break
            similar_items = similarity_finder.find_similar_items(
                query, num_matches)
            print("Results:")
            print("=========")
            for item in similar_items:
                print(item)