def try_roundtrip(self, filename, stage):
    """Parses `filename` with the given stage and round-trips every conformer.

    Each successfully parsed conformer must carry a positive bond topology id
    and, when re-written with SmuWriter, must reproduce the original .dat text.

    Args:
      filename: name of a .dat file under TESTDATA_PATH.
      stage: either 'stage1' or 'stage2'.

    Raises:
      ValueError: if `stage` is not one of the two known stages.
    """
    parser = smu_parser_lib.SmuParser(os.path.join(TESTDATA_PATH, filename))
    writer = smu_writer_lib.SmuWriter(annotate=False)

    # Dispatch table keeps the parser/writer pairing in one place.
    stage_fns = {
        'stage1': (parser.process_stage1, writer.process_stage1_proto),
        'stage2': (parser.process_stage2, writer.process_stage2_proto),
    }
    if stage not in stage_fns:
        raise ValueError(stage)
    process_fn, writer_fn = stage_fns[stage]

    for conformer, original_contents in process_fn():
        # The parser yields exceptions in-band; surface them as test failures.
        if isinstance(conformer, Exception):
            raise conformer
        self.assertGreater(conformer.bond_topologies[0].bond_topology_id, 0)
        smu_writer_lib.check_dat_formats_match(
            original_contents,
            writer_fn(conformer).splitlines())
# Example #2 (score: 0)
def parse_dat_file(filename, stage):
    """Beam pipeline component that reads one SMU .dat file.

    Args:
      filename: filename to read.
      stage: string 'stage1' or 'stage2'; selects the parser pass.

    Yields:
      Pair of string (original dat), conformer.
      conformer can be an Exception or a dataset_pb2.Conformer.
    """
    smu_parser = smu_parser_lib.SmuParser(filename)
    process_fn = (smu_parser.process_stage1
                  if stage == 'stage1' else smu_parser.process_stage2)

    for conformer, dat_lines in process_fn():
        joined_dat = '\n'.join(dat_lines) + '\n'

        # Count every entry read so pipeline progress is observable.
        beam.metrics.Metrics.counter(_METRICS_NAMESPACE,
                                     stage + '_dat_entry_read').inc()

        yield joined_dat, conformer
# Example #3 (score: 0)
    def test_simple(self):
        """End-to-end check that topology_query runs against a small DB."""
        tmp_dir = tempfile.mkdtemp()
        db = smu_sqlite.SMUSQLite(
            os.path.join(tmp_dir, 'query_sqlite_test.sqlite'), 'w')
        parser = smu_parser_lib.SmuParser(
            os.path.join(TESTDATA_PATH, 'pipeline_input_stage2.dat'))
        db.bulk_insert(
            conformer.SerializeToString()
            for conformer, _ in parser.process_stage2())

        flag_overrides = dict(
            bond_lengths_csv=os.path.join(TESTDATA_PATH,
                                          'minmax_bond_distances.csv'),
            bond_topology_csv=os.path.join(TESTDATA_PATH,
                                           'pipeline_bond_topology.csv'))
        with flagsaver.flagsaver(**flag_overrides):
            got = list(query_sqlite.topology_query(db, 'COC(=CF)OC'))

        # These are just the two conformers that came in with this smiles, so
        # no interesting detection happened, but it verifies that the code ran
        # without error.
        self.assertEqual([c.conformer_id for c in got], [618451001, 618451123])
        for conformer in got:
            self.assertLen(conformer.bond_topologies, 1)
            self.assertEqual(conformer.bond_topologies[0].bond_topology_id,
                             618451)
 def setUp(self):
     """Creates the shared SmuParser over the main test .dat file."""
     super().setUp()
     dat_path = os.path.join(TESTDATA_PATH, MAIN_DAT_FILE)
     self.parser = smu_parser_lib.SmuParser(dat_path)
def main(argv):
    """Round-trips SMU .dat files through the parser and writer.

    For every file matching --input_glob, each conformer entry is parsed
    (stage1 or stage2 per --stage), regenerated with SmuWriter, and compared
    against the original text.  Every entry lands in exactly one Outcome
    bucket; each bucket gets a pair of
    <output_stem>_<outcome>_{original,regen}.dat files.  A summary of counts
    is logged and printed at the end.

    Args:
      argv: remaining command-line arguments; only the program name is
        allowed.

    Raises:
      app.UsageError: if extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    smu_writer = smu_writer_lib.SmuWriter(annotate=False)

    # output_files maps from Outcome to a pair of file handles
    # (original contents, regenerated contents).  Built in a loop so the
    # naming scheme lives in one place instead of four copy-pasted stanzas.
    output_files = {}
    for outcome, tag in ((Outcome.SUCCESS, 'success'),
                         (Outcome.MISMATCH, 'mismatch'),
                         (Outcome.PARSE_ERROR_KNOWN, 'parse_error_known'),
                         (Outcome.PARSE_ERROR_UNKNOWN, 'parse_error_unknown')):
        output_files[outcome] = (
            gfile.GFile('{}_{}_original.dat'.format(FLAGS.output_stem, tag),
                        'w'),
            gfile.GFile('{}_{}_regen.dat'.format(FLAGS.output_stem, tag),
                        'w'))

    file_count = 0
    conformer_count = 0
    outcome_counts = collections.Counter()

    for filepath in gfile.glob(FLAGS.input_glob):
        logging.info('Processing file %s', filepath)
        file_count += 1
        smu_parser = smu_parser_lib.SmuParser(filepath)
        if FLAGS.stage == 'stage1':
            process_fn = smu_parser.process_stage1
        else:
            process_fn = smu_parser.process_stage2
        for conformer, orig_contents_list in process_fn():
            conformer_count += 1

            if isinstance(conformer, Exception):
                if isinstance(conformer, smu_parser_lib.SmuKnownError):
                    outcome = Outcome.PARSE_ERROR_KNOWN
                else:
                    outcome = Outcome.PARSE_ERROR_UNKNOWN
                # NOTE(review): assumes every exception yielded by the parser
                # carries a .conformer_id attribute -- confirm against
                # smu_parser_lib; a plain Exception here would raise
                # AttributeError.
                regen_contents = '{}\n{}: {} {}\n'.format(
                    smu_parser_lib.SEPARATOR_LINE, conformer.conformer_id,
                    type(conformer).__name__, str(conformer))
            else:
                if FLAGS.stage == 'stage1':
                    regen_contents = smu_writer.process_stage1_proto(conformer)
                else:
                    regen_contents = smu_writer.process_stage2_proto(conformer)
                try:
                    smu_writer_lib.check_dat_formats_match(
                        orig_contents_list, regen_contents.splitlines())
                    outcome = Outcome.SUCCESS
                except smu_writer_lib.DatFormatMismatchError as e:
                    outcome = Outcome.MISMATCH
                    print(e)

            outcome_counts[outcome] += 1
            output_files[outcome][0].write('\n'.join(orig_contents_list) +
                                           '\n')
            output_files[outcome][1].write(regen_contents)

    for file_orig, file_regen in output_files.values():
        file_orig.close()
        file_regen.close()

    def outcome_status(outcome):
        # Guard against division by zero when no conformers were read.
        if conformer_count:
            percent = outcome_counts[outcome] / conformer_count * 100
        else:
            percent = float('nan')
        return '%5.1f%% %7d %s \n' % (percent, outcome_counts[outcome],
                                      str(outcome))

    status_str = ('COMPLETE: Read %d files, %d conformers\n' %
                  (file_count, conformer_count) +
                  outcome_status(Outcome.SUCCESS) +
                  outcome_status(Outcome.PARSE_ERROR_KNOWN) +
                  outcome_status(Outcome.MISMATCH) +
                  outcome_status(Outcome.PARSE_ERROR_UNKNOWN))

    logging.info(status_str)
    print(status_str)
# Example #6 (score: 0)
def get_stage2_conformer():
  """Returns the first stage2 conformer parsed from the main test .dat file."""
  dat_path = os.path.join(TESTDATA_PATH, MAIN_DAT_FILE)
  conformer, _ = next(smu_parser_lib.SmuParser(dat_path).process_stage2())
  return conformer