def main(unused_argv):
  """Prints, for a sample of molecules, bond topology counts by source.

  About 0.1% of the molecules in the database are re-run through
  topology_from_geom.standard_topology_sensing. For each sampled molecule one
  CSV row is printed with the total number of bond topologies and the number
  whose source has the ITC, MLCR and CSD bits set.

  Args:
    unused_argv: Command-line arguments; ignored.
  """
  # The builtin hash() of a str is salted per interpreter process
  # (PYTHONHASHSEED), which made the sample differ on every run. crc32 is
  # stable, so the same molecules are selected each time.
  import zlib  # pylint: disable=g-import-not-at-top

  db = smu_sqlite.SMUSQLite('20220128_standard_v2.sqlite')
  bond_lengths = bond_length_distribution.AllAtomPairLengthDistributions()
  bond_lengths.add_from_sparse_dataframe_file(
      '20220128_bond_lengths.csv',
      bond_length_distribution.STANDARD_UNBONDED_RIGHT_TAIL_MASS,
      bond_length_distribution.STANDARD_SIG_DIGITS)
  # Empty mapping: every SMILES lookup yields -1.
  fake_smiles_id_dict = collections.defaultdict(lambda: -1)

  print('molecule_id, count_all, count_smu, count_covalent, count_allen')
  for molecule in db:
    sample_key = zlib.crc32(str(molecule.molecule_id).encode('utf-8'))
    if sample_key % 1000 != 1:
      continue

    # NOTE(review): assumes standard_topology_sensing updates
    # molecule.bond_topologies in place; the counts below read it afterwards.
    topology_from_geom.standard_topology_sensing(molecule, bond_lengths,
                                                 fake_smiles_id_dict)

    topologies = molecule.bond_topologies
    count_all = len(topologies)
    # `&` binds tighter than `!=` in Python, so each test is
    # (source & FLAG) != 0.
    count_smu = sum(bt.source & dataset_pb2.BondTopology.SOURCE_ITC != 0
                    for bt in topologies)
    count_covalent = sum(bt.source & dataset_pb2.BondTopology.SOURCE_MLCR != 0
                         for bt in topologies)
    count_allen = sum(bt.source & dataset_pb2.BondTopology.SOURCE_CSD != 0
                      for bt in topologies)
    print(
        f'{molecule.molecule_id}, {count_all}, {count_smu}, {count_covalent}, {count_allen}'
    )
def main(argv):
  """Builds an SMU SQLite database from TFRecord input.

  Optionally loads the SMILES -> bond topology id mapping from
  --bond_topology_csv. Then inserts every record from --input_tfrecord and
  vacuums the database.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.get_absl_handler().use_absl_log_file()

  logging.info('Opening %s', FLAGS.output_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.output_sqlite, 'c')

  if FLAGS.bond_topology_csv:
    logging.info('Starting smiles to btid inserts')
    # Close the CSV as soon as it has been parsed; previously the handle was
    # left open until garbage collection.
    with open(FLAGS.bond_topology_csv) as csv_file:
      smiles_id_dict = smu_utils_lib.smiles_id_dict_from_csv(csv_file)
    db.bulk_insert_smiles(smiles_id_dict.items())
    logging.info('Finished smiles to btid inserts')
  else:
    logging.info('Skipping smiles inserts')

  logging.info('Starting main inserts')
  dataset = tf.data.TFRecordDataset(gfile.glob(FLAGS.input_tfrecord))
  # Records are streamed lazily as raw serialized bytes.
  db.bulk_insert((raw.numpy() for raw in dataset), batch_size=10000)

  logging.info('Starting vacuuming')
  db.vacuum()
  logging.info('Vacuuming finished')
def test_find_by_expanded_stoichiometry_list(self):
  """A list query returns molecules matching any listed expanded stoichiometry."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  fakes = [self.make_fake_molecule(mid) for mid in (2001, 2002, 4004)]
  db.bulk_insert(self.encode_molecules(fakes))

  def mids_for(stoichiometries):
    found = db.find_by_expanded_stoichiometry_list(stoichiometries)
    return [m.molecule_id for m in found]

  self.assertCountEqual(mids_for(['(ch2)2(ch3)2']), [4004])
  self.assertCountEqual(mids_for(['(ch3)2']), [2001, 2002])
  self.assertCountEqual(
      mids_for(['(ch2)2(ch3)2', '(ch3)2']), [2001, 2002, 4004])
  self.assertEmpty(list(db.find_by_expanded_stoichiometry_list(['(nh)'])))
def create_db(self):
  """Creates a new database holding fake conformers 2001, 4001, 6001, 8001."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  conformer_ids = range(2001, 10001, 2000)
  fakes = [self.make_fake_conformer(cid) for cid in conformer_ids]
  db.bulk_insert(self.encode_conformers(fakes))
  return db
def test_repeat_smiles_insert(self):
  """find_by_smiles('CC') returns all three inserted fake conformers."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  fakes = [self.make_fake_conformer(cid) for cid in (2001, 2002, 2003)]
  db.bulk_insert(self.encode_conformers(fakes))

  found = db.find_by_smiles('CC')
  self.assertCountEqual([c.conformer_id for c in found], [2001, 2002, 2003])
def main(argv):
  """Loads serialized TFRecord entries into an SMU SQLite database.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  output_path = FLAGS.output_sqlite
  logging.info('Opening %s', output_path)
  db = smu_sqlite.SMUSQLite(output_path, 'c')

  input_files = gfile.glob(FLAGS.input_tfrecord)
  # The dataset is built immediately; records are then pulled lazily as bytes.
  records = (raw.numpy() for raw in tf.data.TFRecordDataset(input_files))
  db.bulk_insert(records, batch_size=10000)
def test_repeat_smiles_insert(self):
  """A SMILES list query for 'CC' returns all three fake molecules."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  fakes = [self.make_fake_molecule(mid) for mid in (2001, 2002, 2003)]
  db.bulk_insert(self.encode_molecules(fakes))

  found = db.find_by_smiles_list(['CC'], smu_utils_lib.WhichTopologies.ALL)
  self.assertCountEqual([m.molecule_id for m in found], [2001, 2002, 2003])
def test_no_expanded_stoichiometry_on_ineligible(self):
  """An ineligible conformer is found under the empty expanded stoichiometry."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  ineligible = self.make_fake_conformer(2001)
  # A status of 600 is what makes this conformer ineligible.
  ineligible.properties.errors.status = 600
  db.bulk_insert(self.encode_conformers([ineligible]))

  found = db.find_by_expanded_stoichiometry('')
  self.assertCountEqual([c.conformer_id for c in found], [2001])
def test_read(self):
  """A database opened in 'r' mode rejects inserts but still answers lookups."""
  # Build the database, then drop the writable handle before reopening it.
  writer = self.create_db()
  del writer

  db = smu_sqlite.SMUSQLite(self.db_filename, 'r')
  with self.assertRaises(smu_sqlite.ReadOnlyError):
    # NOTE(review): this passes a Conformer proto rather than serialized
    # bytes; confirm bulk_insert accepts that in this API version.
    db.bulk_insert([self.make_fake_conformer(9999)])
  with self.assertRaises(KeyError):
    db.find_by_conformer_id(9999)

  self.assertEqual(db.find_by_conformer_id(4001).conformer_id, 4001)
def main(argv):
  """Selects conformers from an SMU SQLite database and writes them out.

  Conformers are chosen by --cids, --btids, --smiles, --topology_query_smiles
  and, optionally, a random --random_fraction of successful conformers. They
  are written with the outputter picked by --output_format, optionally
  wrapped for topology re-detection.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If --output_format is not one of the handled formats.
    KeyError: If a --btids or --smiles lookup returns an empty result.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.get_absl_handler().use_absl_log_file()

  logging.info('Opening %s', FLAGS.input_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.input_sqlite, 'r')

  fmt = FLAGS.output_format
  path = FLAGS.output_path
  if fmt == OutputFormat.pbtxt:
    outputter = PBTextOutputter(path)
  elif fmt == OutputFormat.sdf_init:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=False)
  elif fmt == OutputFormat.sdf_opt:
    outputter = SDFOutputter(path, init_geometry=False, opt_geometry=True)
  elif fmt == OutputFormat.sdf_init_opt:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=True)
  elif fmt == OutputFormat.atomic_input:
    outputter = AtomicInputOutputter(path)
  elif fmt == OutputFormat.tfdata:
    outputter = TfDataOutputter(path)
  else:
    raise ValueError(f'Bad output format {FLAGS.output_format}')

  if FLAGS.redetect_geometry:
    outputter = ReDetectTopologiesOutputter(outputter)

  with contextlib.closing(outputter):
    for raw_cid in FLAGS.cids:
      outputter.output(db.find_by_conformer_id(int(raw_cid)))

    for raw_btid in FLAGS.btids:
      btid = int(raw_btid)
      matches = db.find_by_bond_topology_id(btid)
      if not matches:
        raise KeyError(f'Bond topology {btid} not found')
      for conformer in matches:
        outputter.output(conformer)

    for smiles in FLAGS.smiles:
      matches = db.find_by_smiles(smiles)
      if not matches:
        raise KeyError(f'SMILES {smiles} not found')
      for conformer in matches:
        outputter.output(conformer)

    for smiles in FLAGS.topology_query_smiles:
      for conformer in topology_query(db, smiles):
        outputter.output(conformer)

    if FLAGS.random_fraction:
      for conformer in db:
        # random.random() is only drawn for successful conformers.
        is_success = conformer.fate == dataset_pb2.Conformer.FATE_SUCCESS
        if is_success and random.random() < FLAGS.random_fraction:
          outputter.output(conformer)
def test_read(self):
  """A database opened in 'r' mode rejects inserts but still answers lookups."""
  # Build the database, then drop the writable handle before reopening it.
  writer = self.create_db()
  del writer

  db = smu_sqlite.SMUSQLite(self.db_filename, 'r')
  with self.assertRaises(smu_sqlite.ReadOnlyError):
    db.bulk_insert(self.encode_molecules([self.make_fake_molecule(9999)]))
  with self.assertRaises(KeyError):
    db.find_by_molecule_id(9999)

  self.assertEqual(db.find_by_molecule_id(4001).molecule_id, 4001)
def test_write(self):
  """Reopening in 'w' mode keeps existing conformers and accepts new ones."""
  writer = self.create_db()
  del writer

  db = smu_sqlite.SMUSQLite(self.db_filename, 'w')
  # create_db makes conformer ids ending in 001. The ones added here end in
  # 005, so the two sets are easy to tell apart.
  added_ids = range(50005, 60005, 2000)
  # NOTE(review): these are Conformer protos, not serialized bytes; confirm
  # bulk_insert accepts them in this API version.
  db.bulk_insert([self.make_fake_conformer(cid) for cid in added_ids])

  # An id that was present before reopening.
  self.assertEqual(db.find_by_conformer_id(4001).conformer_id, 4001)
  # An id added above.
  self.assertEqual(db.find_by_conformer_id(52005).conformer_id, 52005)
def test_find_by_bond_topology_id_source_filtering(self):
  """find_by_bond_topology_id_list honors the WhichTopologies source filter."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')

  bt_pb = dataset_pb2.BondTopology
  # Three molecules:
  #   2001 with bt id 10 (ITC, STARTING) and bt id 11 (MLCR)
  #   4001 with bt id 10 (ITC), bt id 11 (ITC, STARTING), bt id 12 (CSD)
  #   6001 with bt id 12 (MLCR)
  layout = [
      (2001, [(10, bt_pb.SOURCE_STARTING | bt_pb.SOURCE_ITC),
              (11, bt_pb.SOURCE_MLCR)]),
      (4001, [(10, bt_pb.SOURCE_ITC),
              (11, bt_pb.SOURCE_STARTING | bt_pb.SOURCE_ITC),
              (12, bt_pb.SOURCE_CSD)]),
      (6001, [(12, bt_pb.SOURCE_MLCR)]),
  ]
  molecules = []
  for mid, topologies in layout:
    molecule = dataset_pb2.Molecule(molecule_id=mid)
    for bt_id, source in topologies:
      self.add_bond_topology_to_molecule(molecule, bt_id, source)
    molecules.append(molecule)
  db.bulk_insert(self.encode_molecules(molecules))

  def ids_for(bt_id, which):
    found = db.find_by_bond_topology_id_list([bt_id], which)
    return [m.molecule_id for m in found]

  wt = smu_utils_lib.WhichTopologies
  self.assertEqual(ids_for(10, wt.ALL), [2001, 4001])
  self.assertEqual(ids_for(11, wt.ALL), [2001, 4001])
  self.assertEqual(ids_for(12, wt.ALL), [4001, 6001])
  self.assertEqual(ids_for(10, wt.STARTING), [2001])
  self.assertEqual(ids_for(11, wt.MLCR), [2001])
  self.assertEqual(ids_for(12, wt.CSD), [4001])
  self.assertEmpty(ids_for(12, wt.ITC))
  self.assertEmpty(ids_for(11, wt.CSD))
def create_db_with_multiple_bond_topology(self):
  """Creates a database whose conformers carry one or more bond topologies.

  Resulting layout:
    cid 1000 -> btid 1
    cid 2000 -> btid 2, 1
    cid 3000 -> btid 3, 1, 2
  The first btid of each conformer is presumably the one that
  make_fake_conformer assigns. Only the extra ones are added here.
  """
  extra_btids = {1000: [], 2000: [1], 3000: [1, 2]}
  conformers = []
  for cid, btids in extra_btids.items():
    conformer = self.make_fake_conformer(cid)
    for btid in btids:
      self.add_bond_topology_to_conformer(conformer, btid)
    conformers.append(conformer)

  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  db.bulk_insert(self.encode_conformers(conformers))
  return db
def test_find_by_expanded_stoichiometry(self):
  """find_by_expanded_stoichiometry returns exactly the matching conformers."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  fakes = [self.make_fake_conformer(cid) for cid in (2001, 2002, 4004)]
  db.bulk_insert(self.encode_conformers(fakes))

  def cids_for(expanded_stoich):
    found = db.find_by_expanded_stoichiometry(expanded_stoich)
    return [c.conformer_id for c in found]

  self.assertCountEqual(cids_for('(ch2)2(ch3)2'), [4004])
  self.assertCountEqual(cids_for('(ch3)2'), [2001, 2002])
  self.assertEmpty(list(db.find_by_expanded_stoichiometry('(nh)')))
def create_db_with_multiple_bond_topology(self):
  """Creates a database whose molecules carry one or more bond topologies.

  Resulting layout (every extra topology is tagged SOURCE_ITC):
    mid 1000 -> btid 1
    mid 2000 -> btid 2, 1
    mid 3000 -> btid 3, 1, 2
  The first btid of each molecule is presumably the one that
  make_fake_molecule assigns. Only the extra ones are added here.
  """
  extra_btids = {1000: [], 2000: [1], 3000: [1, 2]}
  molecules = []
  for mid, btids in extra_btids.items():
    molecule = self.make_fake_molecule(mid)
    for btid in btids:
      self.add_bond_topology_to_molecule(molecule, btid,
                                         dataset_pb2.BondTopology.SOURCE_ITC)
    molecules.append(molecule)

  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  db.bulk_insert(self.encode_molecules(molecules))
  return db
def test_find_by_stoichiometry(self):
  """find_by_stoichiometry matches conformers and raises on 'P3Li'."""
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')
  fakes = [self.make_fake_conformer(cid) for cid in (2001, 2002, 4004)]
  db.bulk_insert(self.encode_conformers(fakes))

  def cids_for(stoich):
    return [c.conformer_id for c in db.find_by_stoichiometry(stoich)]

  self.assertCountEqual(cids_for('c2h6'), [2001, 2002])
  self.assertCountEqual(cids_for('c4h10'), [4004])
  self.assertEmpty(list(db.find_by_stoichiometry('c3')))

  with self.assertRaises(smu_utils_lib.StoichiometryError):
    db.find_by_stoichiometry('P3Li')
def test_simple(self):
  """topology_query runs end to end on the stage2 pipeline test data."""
  tmp_dir = tempfile.mkdtemp()
  db = smu_sqlite.SMUSQLite(
      os.path.join(tmp_dir, 'query_sqlite_test.sqlite'), 'w')

  stage2_path = os.path.join(TESTDATA_PATH, 'pipeline_input_stage2.dat')
  parser = smu_parser_lib.SmuParser(stage2_path)
  db.bulk_insert(
      record.SerializeToString() for record, _ in parser.process_stage2())

  bond_lengths_csv = os.path.join(TESTDATA_PATH, 'minmax_bond_distances.csv')
  bond_topology_csv = os.path.join(TESTDATA_PATH, 'pipeline_bond_topology.csv')
  with flagsaver.flagsaver(
      bond_lengths_csv=bond_lengths_csv, bond_topology_csv=bond_topology_csv):
    got = list(query_sqlite.topology_query(db, 'COC(=CF)OC'))

  # Only the two conformers that arrived with this SMILES come back, so no
  # interesting detection happened. This just checks the code path runs.
  self.assertEqual([c.conformer_id for c in got], [618451001, 618451123])
  for conformer in got:
    self.assertLen(conformer.bond_topologies, 1)
    self.assertEqual(conformer.bond_topologies[0].bond_topology_id, 618451)
def main(argv):
  """Looks up conformers in an SMU SQLite database and writes them out.

  Conformers are selected by --cids, --btids and --smiles. They are written
  with the outputter picked by --output_format.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If --output_format is not one of the handled formats.
    KeyError: If a --btids or --smiles lookup returns an empty result.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.info('Opening %s', FLAGS.input_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.input_sqlite, 'r')

  fmt = FLAGS.output_format
  path = FLAGS.output_path
  if fmt == OutputFormat.pbtxt:
    outputter = PBTextOutputter(path)
  elif fmt == OutputFormat.sdf_init:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=False)
  elif fmt == OutputFormat.sdf_opt:
    outputter = SDFOutputter(path, init_geometry=False, opt_geometry=True)
  elif fmt == OutputFormat.sdf_init_opt:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=True)
  else:
    raise ValueError(f'Bad output format {FLAGS.output_format}')

  with contextlib.closing(outputter):
    for raw_cid in FLAGS.cids:
      outputter.write(db.find_by_conformer_id(int(raw_cid)))

    for raw_btid in FLAGS.btids:
      btid = int(raw_btid)
      matches = db.find_by_bond_topology_id(btid)
      if not matches:
        raise KeyError(f'Bond topology {btid} not found')
      for conformer in matches:
        outputter.write(conformer)

    for smiles in FLAGS.smiles:
      matches = db.find_by_smiles(smiles)
      if not matches:
        raise KeyError(f'SMILES {smiles} not found')
      for conformer in matches:
        outputter.write(conformer)
def main(argv):
  """Selects molecules from an SMU SQLite database and writes them out.

  Molecules are chosen by --mids, --btids, --smiles, --stoichiometries,
  --topology_query_smiles, --smarts and, optionally, a random
  --random_fraction of all molecules. They are written with the outputter
  picked by --output_format, optionally wrapped for topology re-detection.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If --output_format is not one of the handled formats.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.get_absl_handler().use_absl_log_file()

  logging.info('Opening %s', FLAGS.input_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.input_sqlite, 'r')

  fmt = FLAGS.output_format
  path = FLAGS.output_path
  which = FLAGS.which_topologies
  if fmt == OutputFormat.PBTXT:
    outputter = PBTextOutputter(path)
  elif fmt == OutputFormat.SDF_INIT:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=False,
                             which_topologies=which)
  elif fmt == OutputFormat.SDF_OPT:
    outputter = SDFOutputter(path, init_geometry=False, opt_geometry=True,
                             which_topologies=which)
  elif fmt == OutputFormat.SDF_INIT_OPT:
    outputter = SDFOutputter(path, init_geometry=True, opt_geometry=True,
                             which_topologies=which)
  elif fmt == OutputFormat.ATOMIC2_INPUT:
    outputter = Atomic2InputOutputter(path, which)
  elif fmt == OutputFormat.DAT:
    outputter = DatOutputter(path)
  else:
    raise ValueError(f'Bad output format {FLAGS.output_format}')

  if FLAGS.redetect_topology:
    outputter = ReDetectTopologiesOutputter(outputter, db)

  with contextlib.closing(outputter):
    for raw_mid in FLAGS.mids:
      outputter.output(db.find_by_molecule_id(int(raw_mid)))

    btids = [int(x) for x in FLAGS.btids]
    for molecule in db.find_by_bond_topology_id_list(btids, which):
      outputter.output(molecule)

    for molecule in db.find_by_smiles_list(FLAGS.smiles, which):
      outputter.output(molecule)

    for stoich in FLAGS.stoichiometries:
      for molecule in db.find_by_stoichiometry(stoich):
        outputter.output(molecule)

    for smiles in FLAGS.topology_query_smiles:
      bond_lengths = GeometryData.get_singleton().bond_lengths
      for molecule in db.find_by_topology(smiles, bond_lengths=bond_lengths):
        outputter.output(molecule)

    smarts_query(db, FLAGS.smarts, which, outputter)

    if FLAGS.random_fraction:
      for molecule in db:
        if random.random() < FLAGS.random_fraction:
          outputter.output(molecule)
# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Shows how to check for missing fields.""" from smu import smu_sqlite #----------------------------------------------------------------------------- # Note that we are loading the *complete* database #----------------------------------------------------------------------------- db = smu_sqlite.SMUSQLite('20220128_complete_v2.sqlite') #----------------------------------------------------------------------------- # We'll grab a couple of molecules with different amount of information # stored. #----------------------------------------------------------------------------- PARTIAL_MOLECULE_ID = 35004068 MINIMAL_MOLECULE_ID = 35553043 partial_molecule = db.find_by_molecule_id(PARTIAL_MOLECULE_ID) minimal_molecule = db.find_by_molecule_id(MINIMAL_MOLECULE_ID) print('When you process the *complete* database, you have to be careful to', 'check what data is available') print('We will examine molecules', PARTIAL_MOLECULE_ID, 'and', MINIMAL_MOLECULE_ID)
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """How to check for missing fields.""" from smu import smu_sqlite # Note that we are loading the *complete* database db = smu_sqlite.SMUSQLite('20220104_complete.sqlite', 'r') # We'll grab a couple of conformers with different amount of information # stored. PARTIAL_CONFORMER_ID = 35004068 MINIMAL_CONFORMER_ID = 35553043 partial_conformer = db.find_by_conformer_id(PARTIAL_CONFORMER_ID) minimal_conformer = db.find_by_conformer_id(MINIMAL_CONFORMER_ID) print('When you process the *complete* database, you have to be careful to', 'check what data is available') print('We will examine conformers', PARTIAL_CONFORMER_ID, 'and', MINIMAL_CONFORMER_ID) print(
# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """generates a Pandas DataFrame of some field values.""" import pandas as pd from smu import smu_sqlite db = smu_sqlite.SMUSQLite('20220104_standard.sqlite', 'r') count = 0 data_dict = { 'conformer_id': [], 'energy': [], 'h**o': [], 'lumo': [], 'first important frequency': [], } # This iteration will go through all conformers in the database. for conformer in db: data_dict['conformer_id'].append(conformer.conformer_id) data_dict['energy'].append( conformer.properties.single_point_energy_atomic_b5.value)
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Basic access to different kinds of fields.""" from smu import smu_sqlite db = smu_sqlite.SMUSQLite('20220128_standard_v2.sqlite') #----------------------------------------------------------------------------- # This is an arbitrary choice of the molecule to use. #----------------------------------------------------------------------------- molecule = db.find_by_molecule_id(57001) print('We will examine molecule with id', molecule.molecule_id) print('The computed properties are generally in the .properties field') print('Scalar values are access by name (note the .value suffix),', 'like this single point energy: ', molecule.properties.single_point_energy_atomic_b5.value) print('Fields with repeated values',
def test_find_by_topology(self):
  """find_by_topology re-detects ring and ring-broken topologies.

  The fake molecule is N2OH2: the O at the origin, one N at (0, 1.1) and the
  other at (1.1, 0), and each H just beyond its N. It is stored only with
  the three-membered ring topology N1NO1 (id 100). Querying either that
  topology or N=[NH+][O-] (id 101, registered via bulk_insert_smiles) should
  return the molecule carrying topology 100 plus two copies of 101: the two
  symmetric ways of breaking the ring.
  """
  db = smu_sqlite.SMUSQLite(self.db_filename, 'c')

  molecule = dataset_pb2.Molecule(molecule_id=9999)
  molecule.properties.errors.fate = dataset_pb2.Properties.FATE_SUCCESS
  topology = molecule.bond_topologies.add(
      smiles='N1NO1', bond_topology_id=100)
  positions = molecule.optimized_geometry.atom_positions

  # (atom type, (x, y, z)) with coordinates written in angstroms.
  atom_table = [
      (dataset_pb2.BondTopology.ATOM_O, (0, 0, 0)),
      (dataset_pb2.BondTopology.ATOM_N, (0, 1.1, 0)),
      (dataset_pb2.BondTopology.ATOM_N, (1.1, 0, 0)),
      (dataset_pb2.BondTopology.ATOM_H, (0, 1.2, 0)),
      (dataset_pb2.BondTopology.ATOM_H, (1.2, 0, 0)),
  ]
  for atom_type, (x, y, z) in atom_table:
    topology.atoms.append(atom_type)
    positions.append(dataset_pb2.Geometry.AtomPos(x=x, y=y, z=z))

  # All bonds are single: the O-N-N ring plus one N-H on each nitrogen.
  for atom_a, atom_b in [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)]:
    topology.bonds.append(
        dataset_pb2.BondTopology.Bond(
            atom_a=atom_a,
            atom_b=atom_b,
            bond_type=dataset_pb2.BondTopology.BOND_SINGLE))

  # Convert the angstrom coordinates above into the stored units.
  for pos in positions:
    pos.x /= smu_utils_lib.BOHR_TO_ANGSTROMS
    pos.y /= smu_utils_lib.BOHR_TO_ANGSTROMS
    pos.z /= smu_utils_lib.BOHR_TO_ANGSTROMS

  db.bulk_insert([molecule.SerializeToString()])
  db.bulk_insert_smiles([['N1NO1', 100], ['N=[NH+][O-]', 101]])

  bond_lengths = bond_length_distribution.make_fake_empiricals()

  # Query by the topology stored on the molecule, then by one that is not.
  for query_smiles in ['N1NO1', 'N=[NH+][O-]']:
    got = list(db.find_by_topology(query_smiles, bond_lengths=bond_lengths))
    self.assertLen(got, 1)
    self.assertCountEqual(
        [100, 101, 101],
        [found.bond_topology_id for found in got[0].bond_topologies])