Ejemplo n.º 1
0
    def test_db_version_property(self, tmp_path):
        """Ensure a freshly created store records the expected schema version."""

        store = MoleculeStore(f"{tmp_path}.sqlite")

        # The version row must exist in the underlying database...
        with store._get_session() as db:
            version_row = db.query(DBInformation).first()

            assert version_row is not None
            assert version_row.version == DB_VERSION

        # ...and also be exposed through the public property.
        assert store.db_version == DB_VERSION
Ejemplo n.º 2
0
    def test_store_bond_order_data(self, tmp_path):
        """Re-storing WBOs for an existing conformer / method should raise,
        while a genuinely new conformer is merged into the existing record."""

        def _wbo_record(smiles, coordinates):
            # Build a single-conformer record labelled with AM1 WBOs.
            return MoleculeRecord(
                smiles=smiles,
                conformers=[
                    ConformerRecord(
                        coordinates=coordinates,
                        bond_orders=[
                            WibergBondOrderSet(method="am1", values=[(0, 1, 0.5)])
                        ],
                    )
                ],
            )

        store = MoleculeStore(f"{tmp_path}.sqlite")

        store.store(_wbo_record("[Cl:1][H:2]", numpy.arange(6).reshape((2, 3))))
        assert len(store) == 1

        # The same conformer labelled with the same method must be rejected.
        with pytest.raises(
            RuntimeError, match=re.escape("am1 WBOs already stored for [Cl:1][H:2]")
        ):
            store.store(
                _wbo_record("[Cl:2][H:1]", numpy.arange(6).reshape((2, 3)))
            )

        # A new conformer is merged into the existing molecule record.
        store.store(_wbo_record("[Cl:2][H:1]", numpy.zeros((2, 3))))

        assert len(store) == 1
        assert {*store.wbo_methods} == {"am1"}

        record = store.retrieve()[0]
        assert len(record.conformers) == 2
Ejemplo n.º 3
0
    def test_db_invalid_version(self, tmp_path):
        """Opening a store whose schema version does not match should raise
        ``IncompatibleDBVersion`` carrying both the found and expected version."""

        store = MoleculeStore(f"{tmp_path}.sqlite")

        # Manually corrupt the stored version so it no longer matches the code.
        with store._get_session() as db:
            version_row = db.query(DBInformation).first()
            version_row.version = DB_VERSION - 1

        with pytest.raises(IncompatibleDBVersion) as error_info:
            MoleculeStore(f"{tmp_path}.sqlite")

        assert error_info.value.found_version == DB_VERSION - 1
        assert error_info.value.expected_version == DB_VERSION
Ejemplo n.º 4
0
    def test_provenance_property(self, tmp_path):
        """Provenance dictionaries should round-trip through the store."""

        store = MoleculeStore(f"{tmp_path}.sqlite")

        # A brand new store starts with empty provenance.
        assert store.general_provenance == {}
        assert store.software_provenance == {}

        expected_general = {"author": "Author 1"}
        expected_software = {"psi4": "0.1.0"}

        store.set_provenance(expected_general, expected_software)

        assert store.general_provenance == expected_general
        assert store.software_provenance == expected_software
Ejemplo n.º 5
0
    def _prepare_data_from_path(self, data_paths: List[str]) -> ConcatDataset:
        """Load each path into a ``DGLMoleculeDataset`` and concatenate them.

        Only ``.sqlite`` ``MoleculeStore`` databases are currently supported;
        any other extension raises ``NotImplementedError``.
        """

        def _load_one(path: str):
            # Guard clause: reject anything that is not a SQLite store.
            extension = os.path.splitext(path)[-1].lower()

            if extension != ".sqlite":
                raise NotImplementedError(
                    f"Only paths to SQLite ``MoleculeStore`` databases are supported, and not "
                    f"'{extension}' files.")

            return DGLMoleculeDataset.from_molecule_stores(
                MoleculeStore(path),
                partial_charge_method=self._partial_charge_method,
                bond_order_method=self._bond_order_method,
                atom_features=self._atom_features,
                bond_features=self._bond_features,
                molecule_to_dgl=self._molecule_to_dgl,
            )

        return ConcatDataset([_load_one(data_path) for data_path in data_paths])
Ejemplo n.º 6
0
    def test_match_conformers(self):
        """The matcher should pair incoming conformers with equivalent stored
        ones by index, leaving unmatched conformers out of the mapping."""

        stored_conformers = [
            DBConformerRecord(
                coordinates=numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
            ),
            DBConformerRecord(
                coordinates=numpy.array([[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
            ),
        ]

        incoming_conformers = [
            # Stored conformer 1 rotated 90 degrees in the xy-plane.
            ConformerRecord(
                coordinates=numpy.array([[0.0, -2.0, 0.0], [0.0, 2.0, 0.0]]),
                partial_charges=[],
                bond_orders=[],
            ),
            # Matches neither stored conformer (different bond length).
            ConformerRecord(
                coordinates=numpy.array([[0.0, -2.0, 0.0], [0.0, 3.0, 0.0]]),
                partial_charges=[],
                bond_orders=[],
            ),
            # Stored conformer 0 translated / re-oriented.
            ConformerRecord(
                coordinates=numpy.array([[0.0, 0.0, 0.0], [-2.0, 0.0, 0.0]]),
                partial_charges=[],
                bond_orders=[],
            ),
        ]

        matches = MoleculeStore._match_conformers(
            "[Cl:1][H:2]",
            db_conformers=stored_conformers,
            conformers=incoming_conformers,
        )

        assert matches == {0: 1, 2: 0}
Ejemplo n.º 7
0
def tmp_molecule_store(tmp_path) -> MoleculeStore:
    """Create a temporary store pre-populated with three labelled molecules:
    argon (am1 charges), helium (am1bcc charges) and chlorine (both charge
    methods plus am1 WBOs)."""

    store = MoleculeStore(f"{tmp_path}.sqlite")

    argon = MoleculeRecord(
        smiles="[Ar:1]",
        conformers=[
            ConformerRecord(
                coordinates=numpy.array([[0.0, 0.0, 0.0]]),
                partial_charges=[PartialChargeSet(method="am1", values=[0.5])],
                bond_orders=[],
            )
        ],
    )

    helium = MoleculeRecord(
        smiles="[He:1]",
        conformers=[
            ConformerRecord(
                coordinates=numpy.array([[0.0, 0.0, 0.0]]),
                partial_charges=[PartialChargeSet(method="am1bcc", values=[-0.5])],
                bond_orders=[],
            )
        ],
    )

    chlorine = MoleculeRecord(
        smiles="[Cl:1][Cl:2]",
        conformers=[
            ConformerRecord(
                coordinates=numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                partial_charges=[
                    PartialChargeSet(method="am1", values=[0.5, -0.5]),
                    PartialChargeSet(method="am1bcc", values=[0.75, -0.75]),
                ],
                bond_orders=[
                    WibergBondOrderSet(method="am1", values=[(0, 1, 1.2)])
                ],
            )
        ],
    )

    store.store(argon, helium, chlorine)

    return store
Ejemplo n.º 8
0
def test_label_cli(openff_methane: Molecule, runner):
    """Run the label CLI end-to-end on a single methane molecule and check the
    resulting store contains sensible, non-zero labels."""

    # Write the molecule to disk so the CLI can read it back in.
    openff_methane.to_file("methane.sdf", "sdf")

    result = runner.invoke(
        label_cli,
        ["--input", "methane.sdf", "--output", "labelled.sqlite"],
    )

    if result.exit_code != 0:
        raise result.exception

    assert os.path.isfile("labelled.sqlite")

    store = MoleculeStore("labelled.sqlite")
    assert len(store) == 1

    molecule_record = store.retrieve()[0]
    assert (
        Molecule.from_smiles(molecule_record.smiles).to_smiles()
        == openff_methane.to_smiles()
    )

    assert len(molecule_record.conformers) == 1
    conformer_record = molecule_record.conformers[0]

    assert len(conformer_record.partial_charges) == 2
    assert len(conformer_record.bond_orders) == 1

    # Labelling should have produced non-trivial (i.e. not all zero) values.
    for charge_set in conformer_record.partial_charges:
        assert not all(
            numpy.isclose(charge, 0.0) for charge in charge_set.values
        )

    for bond_order_set in conformer_record.bond_orders:
        assert not all(
            numpy.isclose(value, 0.0) for (_, _, value) in bond_order_set.values
        )
Ejemplo n.º 9
0
def test_data_set_from_molecule_stores(tmpdir):
    """A data set built from a single-molecule store should expose one entry
    with the expected feature count and label shapes."""

    store = MoleculeStore(os.path.join(tmpdir, "store.sqlite"))

    store.store(
        MoleculeRecord(
            smiles="[Cl:1]-[H:2]",
            conformers=[
                ConformerRecord(
                    coordinates=numpy.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                    partial_charges=[
                        PartialChargeSet(method="am1", values=[0.1, -0.1])
                    ],
                    bond_orders=[
                        WibergBondOrderSet(method="am1", values=[(0, 1, 1.1)])
                    ],
                )
            ],
        )
    )

    data_set = DGLMoleculeDataset.from_molecule_stores(
        store, "am1", "am1", [AtomConnectivity()], [BondIsInRing()]
    )

    assert len(data_set) == 1
    assert data_set.n_features == 4

    dgl_molecule, labels = data_set[0]

    assert isinstance(dgl_molecule, DGLMolecule)
    assert dgl_molecule.n_atoms == 2

    # One charge per atom and one WBO per bond.
    assert "am1-charges" in labels
    assert labels["am1-charges"].numpy().shape == (2,)

    assert "am1-wbo" in labels
    assert labels["am1-wbo"].numpy().shape == (1,)
Ejemplo n.º 10
0
    def mock_data_store(self, tmpdir) -> str:
        """Create a one-molecule store on disk and return its path."""

        store_path = os.path.join(tmpdir, "store.sqlite")

        record = MoleculeRecord(
            smiles="[Cl:1][Cl:2]",
            conformers=[
                ConformerRecord(
                    coordinates=numpy.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                    partial_charges=[
                        PartialChargeSet(method="am1bcc", values=[1.0, -1.0])
                    ],
                    bond_orders=[
                        WibergBondOrderSet(method="am1", values=[(0, 1, 1.0)])
                    ],
                )
            ],
        )

        MoleculeStore(store_path).store(record)

        return store_path
Ejemplo n.º 11
0
    def test_store_partial_charge_data(self, tmp_path):
        """Charges from a *new* method merge into an existing conformer, while
        re-storing an already-present method raises."""

        def _charge_record(smiles, coordinates, method, values):
            # Build a single-conformer record with one partial charge set.
            return MoleculeRecord(
                smiles=smiles,
                conformers=[
                    ConformerRecord(
                        coordinates=coordinates,
                        partial_charges=[
                            PartialChargeSet(method=method, values=values)
                        ],
                    )
                ],
            )

        store = MoleculeStore(f"{tmp_path}.sqlite")

        store.store(
            _charge_record(
                "[Cl:1][H:2]", numpy.arange(6).reshape((2, 3)), "am1", [0.50, 1.50]
            )
        )
        assert len(store) == 1

        # The same conformer labelled with a different method is merged into
        # the existing record instead of creating a new one.
        store.store(
            _charge_record(
                "[Cl:2][H:1]",
                numpy.flipud(numpy.arange(6).reshape((2, 3))),
                "am1bcc",
                [0.25, 0.75],
            )
        )

        assert len(store) == 1
        assert {*store.charge_methods} == {"am1", "am1bcc"}

        record = store.retrieve()[0]
        assert len(record.conformers) == 1

        # Storing the same method for the same conformer again must raise.
        with pytest.raises(
            RuntimeError,
            match=re.escape("am1bcc charges already stored for [Cl:1][H:2]"),
        ):
            store.store(
                _charge_record(
                    "[Cl:2][H:1]",
                    numpy.arange(6).reshape((2, 3)),
                    "am1bcc",
                    [0.25, 0.75],
                )
            )

        assert len(store) == 1
        assert {*store.charge_methods} == {"am1", "am1bcc"}

        record = store.retrieve()[0]
        assert len(record.conformers) == 1
Ejemplo n.º 12
0
def label_cli(
    input_path: str,
    output_path: str,
    guess_stereo: bool,
    rms_cutoff: float,
    worker_type: str,
    n_workers: int,
    batch_size: int,
    lsf_memory: int,
    lsf_walltime: str,
    lsf_queue: str,
    lsf_env: str,
):
    """Label the molecules found in ``input_path`` with AM1 / AM1-BCC partial
    charges and AM1 Wiberg bond orders, storing the results in a SQLite
    ``MoleculeStore`` at ``output_path``.

    The work is distributed across a dask cluster (local, or spawned via LSF
    depending on ``worker_type``) in batches of ``batch_size`` molecules.
    Errors raised while labelling or storing are appended to an
    ``*-errors.log`` file next to the output store rather than aborting the
    run.
    """

    from dask import distributed

    root_logger: logging.Logger = logging.getLogger("nagl")
    root_logger.setLevel(logging.INFO)

    root_handler = logging.StreamHandler()
    root_handler.setFormatter(logging.Formatter("%(message)s"))
    # BUGFIX: the handler was previously created but never attached, so the
    # "nagl" logger emitted nothing to the console.
    root_logger.addHandler(root_handler)

    _logger.info("Labeling molecules")

    with capture_toolkit_warnings():

        all_smiles = [
            smiles
            for smiles in tqdm(
                stream_from_file(input_path, as_smiles=True),
                desc="loading molecules",
                ncols=80,
            )
        ]

    unique_smiles = sorted({*all_smiles})

    if len(unique_smiles) != len(all_smiles):

        _logger.warning(
            f"{len(all_smiles) - len(unique_smiles)} duplicate molecules were ignored"
        )

    # BUGFIX: batches are formed from the *unique* molecules, so the batch
    # count (and hence the tqdm total and worker sizing) must be derived from
    # ``unique_smiles`` rather than ``all_smiles``.
    n_batches = int(math.ceil(len(unique_smiles) / batch_size))

    if n_workers < 0:
        n_workers = n_batches

    if n_workers > n_batches:

        _logger.warning(
            f"More workers were requested than there are batches to compute. Only "
            f"{n_batches} workers will be requested."
        )

        n_workers = n_batches

    # Set-up dask to distribute the processing.
    if worker_type == "lsf":
        dask_cluster = setup_dask_lsf_cluster(
            n_workers, lsf_queue, lsf_memory, lsf_walltime, lsf_env
        )
    elif worker_type == "local":
        dask_cluster = setup_dask_local_cluster(n_workers)
    else:
        raise NotImplementedError()

    _logger.info(
        f"{len(unique_smiles)} molecules will be labelled in {n_batches} batches "
        f"across {n_workers} workers\n"
    )

    dask_client = distributed.Client(dask_cluster)

    # Submit the tasks to be computed in chunked batches.
    def batch(iterable):
        n_iterables = len(iterable)

        for i in range(0, n_iterables, batch_size):
            yield iterable[i : min(i + batch_size, n_iterables)]

    futures = [
        dask_client.submit(
            functools.partial(
                label_molecules,
                guess_stereochemistry=guess_stereo,
                partial_charge_methods=["am1", "am1bcc"],
                bond_order_methods=["am1"],
                rms_cutoff=rms_cutoff,
            ),
            batched_molecules,
        )
        for batched_molecules in batch(unique_smiles)
    ]

    # Create a database to store the labelled molecules in and store general
    # provenance information.
    storage = MoleculeStore(output_path)

    storage.set_provenance(
        general_provenance={
            "date": datetime.now().strftime("%d-%m-%Y"),
        },
        software_provenance=get_labelling_software_provenance(),
    )

    # Save out the molecules as they are ready.
    error_file_path = output_path.replace(".sqlite", "-errors.log")

    with open(error_file_path, "w") as file:

        for future in tqdm(
            distributed.as_completed(futures, raise_errors=False),
            total=n_batches,
            desc="labelling molecules",
            ncols=80,
        ):

            for molecule_record, error in tqdm(
                future.result(),
                desc="storing batch",
                ncols=80,
            ):

                try:

                    with capture_toolkit_warnings():

                        if molecule_record is not None and error is None:
                            storage.store(molecule_record)

                except BaseException as e:

                    # BUGFIX: ``traceback.format_exception`` dropped its
                    # ``etype`` keyword in Python 3.10 — call it positionally,
                    # and join the returned lines into a readable string
                    # instead of interpolating the raw list.
                    formatted_traceback = "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    )
                    error = f"Could not store record: {formatted_traceback}"

                if error is not None:

                    file.write("=" * 79 + "\n")
                    file.write(error + "\n")
                    file.flush()

                    continue

            future.release()

    # Release the LSF workers once all batches have been processed.
    if worker_type == "lsf":
        dask_cluster.scale(n=0)