Example #1
  def test_invalid_graph_file(self):
    filename = tempfile.NamedTemporaryFile(delete=False).name
    with open(filename, 'w') as f:
      for graph in self.graphs:
        graph_io.write_graph(graph, f)

    with self.assertRaisesRegex(AssertionError, 'is not a stats graph'):
      graph_io.get_stats(filename)
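This test writes graph records only and never appends a stats record, so graph_io.get_stats has nothing to read back; the error text ('is not a stats graph') suggests the library stores stats as a specially tagged graph record at the end of the file, as Example #4 below demonstrates.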
Example #2
def main(argv):
    if len(argv) > 1:
        raise RuntimeError(f'Unexpected arguments: {argv[1:]}')

    input_stats = graph_io.get_stats(FLAGS.in_file)
    max_importance = input_stats['max_final_importance']
    with open(FLAGS.in_file) as input_file:
        rejector = molecule_sampler.RejectToUniform(
            base_iter=graph_io.graph_reader(input_file),
            max_importance=max_importance,
            rng_seed=FLAGS.seed)
        with open(FLAGS.out_file, 'w') as output_file:
            for graph in rejector:
                graph_io.write_graph(graph, output_file)
                if rejector.num_accepted % 10000 == 0:
                    acc = rejector.num_accepted
                    proc = rejector.num_processed
                    print(f'Accepted {acc}/{proc}: {acc / proc * 100:.2f}%')

            output_stats = dict(
                num_samples=rejector.num_accepted,
                estimated_num_graphs=input_stats['estimated_num_graphs'],
                rng_seed=rejector.rng_seed)
            graph_io.write_stats(output_stats, output_file)

    acc = rejector.num_accepted
    proc = rejector.num_processed
    print(f'Done rejecting to uniform! Accepted {acc}/{proc}: '
          f'{acc / proc * 100:.2f}%')
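For intuition, the following is a minimal sketch of the rejection rule such a reject-to-uniform pass applies; it is an illustration under stated assumptions, not the actual molecule_sampler.RejectToUniform. It assumes each graph arrives paired with an importance weight proportional to the reciprocal of the probability it was drawn with (the script above reads the maximum of these weights from the input file's 'max_final_importance' stat); accepting with probability weight / max_weight then gives every graph the same overall survival probability, i.e. a uniform sample.

import random

def reject_to_uniform(weighted_graphs, max_weight, rng_seed=0):
    """Sketch: thin an importance-weighted stream to a uniform sample.

    Assumes graph i was drawn with probability proportional to
    1 / weight_i, so accepting with probability weight_i / max_weight
    makes every graph's overall survival probability the same constant.
    """
    rng = random.Random(rng_seed)
    for graph, weight in weighted_graphs:
        if rng.random() < weight / max_weight:
            yield graph

The real class additionally tracks num_accepted and num_processed, which the progress reporting above relies on.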
Example #3
def main(argv):
    graph_fnames = argv[1:]

    bucket_sizes, sample_sizes = [], []
    for graph_fname in graph_fnames:
        stats = graph_io.get_stats(graph_fname)
        bucket_sizes.append(stats['estimated_num_graphs'])
        sample_sizes.append(stats['num_samples'])

    def graph_iter(graph_fname):
        # Stream graphs from one file, keeping it open only while the
        # iterator is being consumed.
        with open(graph_fname) as graph_file:
            yield from graph_io.graph_reader(graph_file)

    base_iters = (graph_iter(graph_fname) for graph_fname in graph_fnames)
    aggregator = molecule_sampler.AggregateUniformSamples(
        bucket_sizes=bucket_sizes,
        sample_sizes=sample_sizes,
        base_iters=base_iters,
        target_num_samples=FLAGS.target_samples,
        rng_seed=FLAGS.seed)

    with open(FLAGS.output, 'w') as output_file:
        for graph in aggregator:
            graph_io.write_graph(graph, output_file)
            if aggregator.num_accepted % 10000 == 0:
                print(
                    f'Working on file {aggregator.num_iters_started}/'
                    f'{len(graph_fnames)}. Accepted {aggregator.num_accepted}/'
                    f'{aggregator.num_proccessed} so far.')

        stats = dict(target_num_samples=aggregator.target_num_samples,
                     num_samples=aggregator.num_accepted,
                     rng_seed=aggregator.rng_seed,
                     estimated_total_num_graphs=sum(bucket_sizes))
        graph_io.write_stats(stats, output_file)

    acc = aggregator.num_accepted
    proc = aggregator.num_proccessed
    print(f'Done aggregating uniform samples! Accepted {acc}/{proc}: '
          f'{acc / proc * 100:.2f}%')
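A rough sketch of what such an aggregation has to do (again an illustration under assumptions, not the actual molecule_sampler.AggregateUniformSamples): each input file holds sample_sizes[i] graphs drawn uniformly from an estimated bucket_sizes[i], so bucket i's graphs currently survive at rate sample_sizes[i] / bucket_sizes[i]. A uniform sample over the union needs one common rate, target_num_samples / sum(bucket_sizes), so each bucket's stream is thinned by the ratio of the two rates.

import random

def aggregate_uniform_samples(base_iters, bucket_sizes, sample_sizes,
                              target_num_samples, rng_seed=0):
    """Sketch: merge per-bucket uniform samples into one uniform sample."""
    rng = random.Random(rng_seed)
    overall_rate = target_num_samples / sum(bucket_sizes)
    for base_iter, bucket, sample in zip(base_iters, bucket_sizes,
                                         sample_sizes):
        # This bucket's graphs already survive at rate sample / bucket;
        # keep each with the probability that lowers this to the common
        # overall_rate. Capped at 1: an under-sampled bucket can only be
        # kept in full, not boosted.
        keep_prob = min(1.0, overall_rate * bucket / sample)
        for graph in base_iter:
            if rng.random() < keep_prob:
                yield graph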
Example #4
  def test_write_read_graphs(self):
    filename = tempfile.NamedTemporaryFile(delete=False).name

    # Write a file.
    stats = dict(summary='Some summary', quality=100.0)
    with open(filename, 'w') as f:
      for graph in self.graphs:
        graph_io.write_graph(graph, f)
      graph_io.write_stats(stats, f)

    # Check we can recover the data.
    recovered_stats = graph_io.get_stats(filename)
    recovered_graphs = []
    with open(filename, 'r') as f:
      for graph in graph_io.graph_reader(f):
        recovered_graphs.append(graph)

    self.assertEqual(stats, recovered_stats)
    self.assertEqual(len(self.graphs), len(recovered_graphs))
    for g1, g2 in zip(self.graphs, recovered_graphs):
      self.assertTrue(molecule_sampler.is_isomorphic(g1, g2))

    os.remove(filename)
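The round trip relies on the file layout established here: graphs are streamed one by one with write_graph and the stats dict is appended last with write_stats, yet get_stats(filename) still locates it, and graph_reader yields only the graph records, skipping the stats entry.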
Example #5
def main(argv):
    filenames = argv[1:]
    stats_list = [graph_io.get_stats(filename) for filename in filenames]
    df = pd.DataFrame(stats_list, index=pd.Index(filenames, name='filename'))
    df.to_csv(FLAGS.output)
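Given the stats records written by the scripts above, the resulting CSV has one row per input file, indexed by filename, with one column per stats key (for example num_samples, estimated_num_graphs, and rng_seed).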