def test_db_version_property(self, tmp_path):
    """Ensure a freshly created store records the expected schema version."""

    molecule_store = MoleculeStore(f"{tmp_path}.sqlite")

    # The version row should be written as part of store creation.
    with molecule_store._get_session() as session:
        version_row = session.query(DBInformation).first()

        assert version_row is not None
        assert version_row.version == DB_VERSION

    # The public property should surface the same value.
    assert molecule_store.db_version == DB_VERSION
def test_store_bond_order_data(self, tmp_path):
    """Re-storing WBOs for an existing conformer raises, while a genuinely new
    conformer of the same molecule is appended to the existing record."""

    def _wbo_record(smiles, coordinates):
        # Build a single-conformer record carrying one am1 WBO set.
        return MoleculeRecord(
            smiles=smiles,
            conformers=[
                ConformerRecord(
                    coordinates=coordinates,
                    bond_orders=[
                        WibergBondOrderSet(method="am1", values=[(0, 1, 0.5)])
                    ],
                )
            ],
        )

    store = MoleculeStore(f"{tmp_path}.sqlite")

    store.store(_wbo_record("[Cl:1][H:2]", numpy.arange(6).reshape((2, 3))))
    assert len(store) == 1

    # Submitting the same conformer again (under a re-mapped SMILES) with the
    # same method must be rejected.
    with pytest.raises(
        RuntimeError, match=re.escape("am1 WBOs already stored for [Cl:1][H:2]")
    ):
        store.store(_wbo_record("[Cl:2][H:1]", numpy.arange(6).reshape((2, 3))))

    # Different coordinates constitute a new conformer and should be accepted.
    store.store(_wbo_record("[Cl:2][H:1]", numpy.zeros((2, 3))))

    assert len(store) == 1
    assert {*store.wbo_methods} == {"am1"}

    stored_record = store.retrieve()[0]
    assert len(stored_record.conformers) == 2
def test_db_invalid_version(self, tmp_path):
    """Opening a store whose schema version differs from ``DB_VERSION`` must
    raise ``IncompatibleDBVersion`` carrying both the found and expected
    versions."""

    store_path = f"{tmp_path}.sqlite"
    store = MoleculeStore(store_path)

    # Manually downgrade the stored version to simulate an old database file.
    with store._get_session() as session:
        session.query(DBInformation).first().version = DB_VERSION - 1

    with pytest.raises(IncompatibleDBVersion) as raised:
        MoleculeStore(store_path)

    assert raised.value.found_version == DB_VERSION - 1
    assert raised.value.expected_version == DB_VERSION
def test_provenance_property(self, tmp_path):
    """Provenance dictionaries default to empty and round-trip through
    ``set_provenance``."""

    store = MoleculeStore(f"{tmp_path}.sqlite")

    # A new store starts with no provenance of either kind.
    assert store.general_provenance == {}
    assert store.software_provenance == {}

    general = {"author": "Author 1"}
    software = {"psi4": "0.1.0"}

    store.set_provenance(general, software)

    assert store.general_provenance == general
    assert store.software_provenance == software
def _prepare_data_from_path(self, data_paths: List[str]) -> ConcatDataset:
    """Load one DGL dataset per SQLite store path and concatenate them.

    Parameters
    ----------
    data_paths
        Paths to ``MoleculeStore`` SQLite databases.

    Returns
    -------
        The concatenation of the per-path datasets.

    Raises
    ------
    NotImplementedError
        If any path does not have a ``.sqlite`` extension.
    """

    loaded_datasets = []

    for path in data_paths:

        suffix = os.path.splitext(path)[-1].lower()

        # Only SQLite-backed stores are currently supported.
        if suffix != ".sqlite":
            raise NotImplementedError(
                f"Only paths to SQLite ``MoleculeStore`` databases are supported, and not "
                f"'{suffix}' files.")

        loaded_datasets.append(
            DGLMoleculeDataset.from_molecule_stores(
                MoleculeStore(path),
                partial_charge_method=self._partial_charge_method,
                bond_order_method=self._bond_order_method,
                atom_features=self._atom_features,
                bond_features=self._bond_features,
                molecule_to_dgl=self._molecule_to_dgl,
            )
        )

    return ConcatDataset(loaded_datasets)
def test_match_conformers(self):
    """Incoming conformers should be paired with matching stored conformers by
    index, with unmatched conformers omitted from the result."""

    def _bare_conformer(coordinates):
        # Conformer with no labels, used purely for geometric matching.
        return ConformerRecord(
            coordinates=coordinates, partial_charges=[], bond_orders=[]
        )

    stored = [
        DBConformerRecord(
            coordinates=numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
        ),
        DBConformerRecord(
            coordinates=numpy.array([[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
        ),
    ]
    incoming = [
        _bare_conformer(numpy.array([[0.0, -2.0, 0.0], [0.0, 2.0, 0.0]])),
        _bare_conformer(numpy.array([[0.0, -2.0, 0.0], [0.0, 3.0, 0.0]])),
        _bare_conformer(numpy.array([[0.0, 0.0, 0.0], [-2.0, 0.0, 0.0]])),
    ]

    matches = MoleculeStore._match_conformers(
        "[Cl:1][H:2]", db_conformers=stored, conformers=incoming
    )

    # Incoming 0 pairs with stored 1, incoming 2 with stored 0; incoming 1
    # matches nothing and is absent from the mapping.
    assert matches == {0: 1, 2: 0}
def tmp_molecule_store(tmp_path) -> MoleculeStore:
    """Create a temporary store pre-populated with three known records: Ar
    (am1 charges), He (am1bcc charges) and Cl2 (both charge sets plus am1
    WBOs)."""

    def _single_conformer_record(smiles, coordinates, charge_sets, wbo_sets):
        return MoleculeRecord(
            smiles=smiles,
            conformers=[
                ConformerRecord(
                    coordinates=coordinates,
                    partial_charges=charge_sets,
                    bond_orders=wbo_sets,
                )
            ],
        )

    records = [
        _single_conformer_record(
            "[Ar:1]",
            numpy.array([[0.0, 0.0, 0.0]]),
            [PartialChargeSet(method="am1", values=[0.5])],
            [],
        ),
        _single_conformer_record(
            "[He:1]",
            numpy.array([[0.0, 0.0, 0.0]]),
            [PartialChargeSet(method="am1bcc", values=[-0.5])],
            [],
        ),
        _single_conformer_record(
            "[Cl:1][Cl:2]",
            numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
            [
                PartialChargeSet(method="am1", values=[0.5, -0.5]),
                PartialChargeSet(method="am1bcc", values=[0.75, -0.75]),
            ],
            [WibergBondOrderSet(method="am1", values=[(0, 1, 1.2)])],
        ),
    ]

    store = MoleculeStore(f"{tmp_path}.sqlite")
    store.store(*records)

    return store
def test_label_cli(openff_methane: Molecule, runner):
    """Run the label CLI end-to-end on a single-molecule SDF file and check
    the resulting store contents."""

    # Create an SDF file to label.
    openff_methane.to_file("methane.sdf", "sdf")

    result = runner.invoke(
        label_cli, ["--input", "methane.sdf", "--output", "labelled.sqlite"]
    )

    if result.exit_code != 0:
        raise result.exception

    assert os.path.isfile("labelled.sqlite")
    store = MoleculeStore("labelled.sqlite")

    assert len(store) == 1
    molecule_record = store.retrieve()[0]

    # The stored SMILES should round-trip to the input molecule.
    assert (Molecule.from_smiles(
        molecule_record.smiles).to_smiles() == openff_methane.to_smiles())

    assert len(molecule_record.conformers) == 1
    conformer_record = molecule_record.conformers[0]

    assert len(conformer_record.partial_charges) == 2
    assert len(conformer_record.bond_orders) == 1

    # The labels should actually have been computed, i.e. not be all zero.
    for charge_set in conformer_record.partial_charges:
        assert any(
            not numpy.isclose(charge, 0.0) for charge in charge_set.values
        )
    for wbo_set in conformer_record.bond_orders:
        assert any(
            not numpy.isclose(order, 0.0) for (_, _, order) in wbo_set.values
        )
def test_data_set_from_molecule_stores(tmpdir):
    """A dataset built from a store should expose one entry whose labels match
    the stored charge and WBO methods."""

    store = MoleculeStore(os.path.join(tmpdir, "store.sqlite"))
    store.store(
        MoleculeRecord(
            smiles="[Cl:1]-[H:2]",
            conformers=[
                ConformerRecord(
                    coordinates=numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                    partial_charges=[
                        PartialChargeSet(method="am1", values=[0.1, -0.1])
                    ],
                    bond_orders=[
                        WibergBondOrderSet(method="am1", values=[(0, 1, 1.1)])
                    ],
                )
            ],
        )
    )

    data_set = DGLMoleculeDataset.from_molecule_stores(
        store, "am1", "am1", [AtomConnectivity()], [BondIsInRing()]
    )

    assert len(data_set) == 1
    assert data_set.n_features == 4

    dgl_molecule, labels = data_set[0]

    assert isinstance(dgl_molecule, DGLMolecule)
    assert dgl_molecule.n_atoms == 2

    # One charge per atom and one WBO per bond.
    assert "am1-charges" in labels
    assert labels["am1-charges"].numpy().shape == (2,)

    assert "am1-wbo" in labels
    assert labels["am1-wbo"].numpy().shape == (1,)
def mock_data_store(self, tmpdir) -> str:
    """Create a store containing a single labelled Cl2 record and return the
    path to its SQLite file."""

    path = os.path.join(tmpdir, "store.sqlite")

    MoleculeStore(path).store(
        MoleculeRecord(
            smiles="[Cl:1][Cl:2]",
            conformers=[
                ConformerRecord(
                    coordinates=numpy.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                    partial_charges=[
                        PartialChargeSet(method="am1bcc", values=[1.0, -1.0])
                    ],
                    bond_orders=[
                        WibergBondOrderSet(method="am1", values=[(0, 1, 1.0)])
                    ],
                )
            ],
        )
    )

    return path
def test_store_partial_charge_data(self, tmp_path):
    """Charges computed with a new method merge into the existing conformer,
    while re-storing an already-present method raises."""

    def _charge_record(smiles, coordinates, method, charges):
        # Build a single-conformer record carrying one partial-charge set.
        return MoleculeRecord(
            smiles=smiles,
            conformers=[
                ConformerRecord(
                    coordinates=coordinates,
                    partial_charges=[
                        PartialChargeSet(method=method, values=charges)
                    ],
                )
            ],
        )

    store = MoleculeStore(f"{tmp_path}.sqlite")

    store.store(
        _charge_record(
            "[Cl:1][H:2]", numpy.arange(6).reshape((2, 3)), "am1", [0.50, 1.50]
        )
    )
    assert len(store) == 1

    # The same conformer under a re-mapped SMILES (with coordinates reordered
    # to match) should gain the new method rather than a new conformer.
    store.store(
        _charge_record(
            "[Cl:2][H:1]",
            numpy.flipud(numpy.arange(6).reshape((2, 3))),
            "am1bcc",
            [0.25, 0.75],
        )
    )

    assert len(store) == 1
    assert {*store.charge_methods} == {"am1", "am1bcc"}

    record = store.retrieve()[0]
    assert len(record.conformers) == 1

    # Storing a method that is already present must be rejected.
    with pytest.raises(
        RuntimeError,
        match=re.escape("am1bcc charges already stored for [Cl:1][H:2]"),
    ):
        store.store(
            _charge_record(
                "[Cl:2][H:1]",
                numpy.arange(6).reshape((2, 3)),
                "am1bcc",
                [0.25, 0.75],
            )
        )

    assert len(store) == 1
    assert {*store.charge_methods} == {"am1", "am1bcc"}

    record = store.retrieve()[0]
    assert len(record.conformers) == 1
def label_cli(
    input_path: str,
    output_path: str,
    guess_stereo: bool,
    rms_cutoff: float,
    worker_type: str,
    n_workers: int,
    batch_size: int,
    lsf_memory: int,
    lsf_walltime: str,
    lsf_queue: str,
    lsf_env: str,
):
    """Label molecules loaded from ``input_path`` with partial charges and
    Wiberg bond orders, distributing batches across dask workers, and store
    the results in a ``MoleculeStore`` SQLite database at ``output_path``.

    Records that fail to label or store are skipped and their errors appended
    to ``<output_path stem>-errors.log``.
    """

    from dask import distributed

    root_logger: logging.Logger = logging.getLogger("nagl")
    root_logger.setLevel(logging.INFO)

    root_handler = logging.StreamHandler()
    root_handler.setFormatter(logging.Formatter("%(message)s"))
    # FIX: the handler was previously created but never attached, so the
    # configured message format was silently unused.
    root_logger.addHandler(root_handler)

    _logger.info("Labeling molecules")

    with capture_toolkit_warnings():
        all_smiles = [
            smiles
            for smiles in tqdm(
                stream_from_file(input_path, as_smiles=True),
                desc="loading molecules",
                ncols=80,
            )
        ]

    unique_smiles = sorted({*all_smiles})

    if len(unique_smiles) != len(all_smiles):
        _logger.warning(
            f"{len(all_smiles) - len(unique_smiles)} duplicate molecules were ignored"
        )

    # FIX: batches are built from the de-duplicated list, so the batch count
    # must come from ``unique_smiles``. It was previously computed from
    # ``all_smiles``, over-counting batches (and skewing the progress bar's
    # ``total`` and the requested worker count) whenever duplicates existed.
    n_batches = int(math.ceil(len(unique_smiles) / batch_size))

    if n_workers < 0:
        # A negative worker count means "one worker per batch".
        n_workers = n_batches

    if n_workers > n_batches:
        _logger.warning(
            f"More workers were requested then there are batches to compute. Only "
            f"{n_batches} workers will be requested."
        )
        n_workers = n_batches

    # Set-up dask to distribute the processing.
    if worker_type == "lsf":
        dask_cluster = setup_dask_lsf_cluster(
            n_workers, lsf_queue, lsf_memory, lsf_walltime, lsf_env
        )
    elif worker_type == "local":
        dask_cluster = setup_dask_local_cluster(n_workers)
    else:
        raise NotImplementedError()

    _logger.info(
        f"{len(unique_smiles)} molecules will labelled in {n_batches} batches across "
        f"{n_workers} workers\n"
    )

    dask_client = distributed.Client(dask_cluster)

    # Submit the tasks to be computed in chunked batches.
    def batch(iterable):
        # Yield successive ``batch_size``-sized slices of ``iterable``.
        n_iterables = len(iterable)

        for i in range(0, n_iterables, batch_size):
            yield iterable[i : min(i + batch_size, n_iterables)]

    futures = [
        dask_client.submit(
            functools.partial(
                label_molecules,
                guess_stereochemistry=guess_stereo,
                partial_charge_methods=["am1", "am1bcc"],
                bond_order_methods=["am1"],
                rms_cutoff=rms_cutoff,
            ),
            batched_molecules,
        )
        for batched_molecules in batch(unique_smiles)
    ]

    # Create a database to store the labelled molecules in and store general
    # provenance information.
    storage = MoleculeStore(output_path)

    storage.set_provenance(
        general_provenance={
            "date": datetime.now().strftime("%d-%m-%Y"),
        },
        software_provenance=get_labelling_software_provenance(),
    )

    # Save out the molecules as they are ready.
    error_file_path = output_path.replace(".sqlite", "-errors.log")

    with open(error_file_path, "w") as file:

        for future in tqdm(
            distributed.as_completed(futures, raise_errors=False),
            total=n_batches,
            desc="labelling molecules",
            ncols=80,
        ):

            for molecule_record, error in tqdm(
                future.result(),
                desc="storing batch",
                ncols=80,
            ):

                try:
                    with capture_toolkit_warnings():
                        if molecule_record is not None and error is None:
                            storage.store(molecule_record)
                except BaseException as e:
                    # FIX: ``traceback.format_exception(etype=...)`` raises a
                    # ``TypeError`` on Python 3.10+ (the keyword was removed);
                    # pass the arguments positionally and join the returned
                    # lines into a readable message instead of interpolating
                    # the raw list.
                    formatted_traceback = "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    )
                    error = f"Could not store record: {formatted_traceback}"

                if error is not None:
                    file.write("=".join(["="] * 40) + "\n")
                    file.write(error + "\n")
                    file.flush()

                    continue

            future.release()

    if worker_type == "lsf":
        # Spin the LSF workers down once all batches have been consumed.
        dask_cluster.scale(n=0)