def test_ethane(self):
        """Score every bond-type choice for the simplest molecule, CC.

        A two-carbon topology with no bonds is paired with a distinct score for
        each of the four bond-type indices of the single atom pair (0, 1).
        The search state must then offer exactly those four choices, and
        placing choice `i` must give a result whose score is `scores[i]`.
        """
        cc = text_format.Parse("""
      atoms: ATOM_C
      atoms: ATOM_C
""", dataset_pb2.BondTopology())
        scores = np.array([0.1, 1.1, 2.1, 3.1], dtype=np.float32)
        bonds_to_scores = {(0, 1): scores}
        matching_parameters = smu_molecule.MatchingParameters()
        matching_parameters.must_match_all_bonds = False
        mol = smu_molecule.SmuMolecule(cc, bonds_to_scores,
                                       matching_parameters)
        state = mol.generate_search_state()
        self.assertLen(state, 1)
        self.assertEqual(state, [[0, 1, 2, 3]])

        for i, s in enumerate(itertools.product(*state)):
            # Hand the matching parameters to place_bonds explicitly, the same
            # way the other place_bonds callers in this file do.
            res = mol.place_bonds(s, matching_parameters)
            self.assertIsNotNone(res)
            self.assertAlmostEqual(res.score, scores[i])
    def test_propane_all(self, btype1, btype2, expected_bonds, expected_score):
        cc = text_format.Parse(
            """
      atoms: ATOM_C
      atoms: ATOM_C
      atoms: ATOM_C
""", dataset_pb2.BondTopology())
        #   print(f"Generating bonds {btype1} and {btype2}")
        bonds_to_scores = {
            (0, 1): np.zeros(4, dtype=np.float32),
            (1, 2): np.zeros(4, dtype=np.float32)
        }
        bonds_to_scores[(0, 1)][btype1] = 1.0
        bonds_to_scores[(1, 2)][btype2] = 1.0
        matching_parameters = smu_molecule.MatchingParameters()
        matching_parameters.must_match_all_bonds = False
        mol = smu_molecule.SmuMolecule(cc, bonds_to_scores,
                                       matching_parameters)
        state = mol.generate_search_state()
        for s in itertools.product(*state):
            res = mol.place_bonds(s, matching_parameters)
            if expected_score is not None:
                self.assertIsNotNone(res)
                self.assertLen(res.bonds, expected_bonds)
                self.assertAlmostEqual(res.score, expected_score)
                if btype1 == 0:
                    if btype2 > 0:
                        self.assertEqual(res.bonds[0].bond_type, btype2)
                else:
                    self.assertEqual(res.bonds[0].bond_type, btype1)
                    self.assertEqual(res.bonds[1].bond_type, btype2)
            else:
                self.assertIsNone(res)
    def test_ethane_all(self, btype, expected_bond):
        cc = text_format.Parse("""
      atoms: ATOM_C
      atoms: ATOM_C
""", dataset_pb2.BondTopology())
        bonds_to_scores = {(0, 1): np.zeros(4, dtype=np.float32)}
        bonds_to_scores[(0, 1)][btype] = 1.0
        matching_parameters = smu_molecule.MatchingParameters()
        matching_parameters.must_match_all_bonds = False
        mol = smu_molecule.SmuMolecule(cc, bonds_to_scores,
                                       matching_parameters)
        state = mol.generate_search_state()
        for s in itertools.product(*state):
            res = mol.place_bonds(s, matching_parameters)
            if btype == 0:
                self.assertIsNone(res)
            else:
                self.assertIsNotNone(res)
                self.assertLen(res.bonds, 1)
                self.assertEqual(res.bonds[0].bond_type, expected_bond)
    def test_operators(self):
        cc = text_format.Parse(
            """
      atoms: ATOM_C
      atoms: ATOM_C
      atoms: ATOM_C
""", dataset_pb2.BondTopology())
        #   print(f"Generating bonds {btype1} and {btype2}")
        bonds_to_scores = {
            (0, 1): np.zeros(4, dtype=np.float32),
            (1, 2): np.zeros(4, dtype=np.float32)
        }
        scores = np.array([1.0, 3.0], dtype=np.float32)
        bonds_to_scores[(0, 1)][1] = scores[0]
        bonds_to_scores[(1, 2)][1] = scores[1]
        matching_parameters = smu_molecule.MatchingParameters()
        matching_parameters.must_match_all_bonds = False
        mol = smu_molecule.SmuMolecule(cc, bonds_to_scores,
                                       matching_parameters)
        mol.set_initial_score_and_incrementer(1.0, operator.mul)
        state = mol.generate_search_state()
        for s in itertools.product(*state):
            res = mol.place_bonds(s, matching_parameters)
            self.assertAlmostEqual(res.score, np.product(scores))
Example #5
0
def bond_topologies_from_geom(bond_lengths, conformer_id, fate, bond_topology,
                              geometry, matching_parameters):
    """Enumerate the BondTopology candidates compatible with a geometry.

    A starting topology is obtained from `hydrogen_to_nearest_atom`.  Every
    pair of non-hydrogen atoms whose distance does not exceed `THRESHOLD` is
    scored per bond type through `bond_lengths`, and each combination offered
    by `SmuMolecule.generate_search_state` is tried.  Candidates that survive
    the filters below are collected, sorted by descending score, and given a
    `topology_score` of log(score / sum of all scores); the raw `score` field
    is then cleared.

    NOTE(review): the original docstring stated that `bond_topology` is put
    in a canonical form.  In this function only the generated candidates are
    passed to `utilities.canonical_bond_topology` -- confirm whether a helper
    also canonicalises the input.

    Args:
      bond_lengths: object providing `probability_of_bond_types(atom_a,
        atom_b, distance)`; also forwarded to `geometry_score`.
      conformer_id: copied into the result's `conformer_id`.
      fate: outcome of calculations; copied into the result's `fate`.
      bond_topology: the starting topology; its `atoms`, `smiles` and
        `bond_topology_id` are read.
      geometry: coordinates for `bond_topology`; `atom_positions` must hold
        one entry per atom or an empty result is returned.
      matching_parameters: forwarded to `SmuMolecule`; the flags
        `consider_not_bonded`, `ring_atom_count_cannot_decrease` and
        `smiles_with_h` are read here.

    Returns:
      A `dataset_pb2.TopologyMatches`.  Its `bond_topology` list is empty when
      there is a single atom, the geometry size does not match, no starting
      topology could be built, or no atom pair received scores.
    """
    matches = dataset_pb2.TopologyMatches()
    matches.starting_smiles = bond_topology.smiles
    matches.conformer_id = conformer_id
    matches.fate = fate

    atom_count = len(bond_topology.atoms)
    # Nothing to do for a lone atom, or when the geometry does not describe
    # the same number of atoms as the topology.
    if atom_count == 1 or len(geometry.atom_positions) != atom_count:
        return matches

    dist_matrix = utilities.distances(geometry)

    # Attach each hydrogen to its nearest heavy atom; this is the topology
    # from which every candidate is grown.
    seed_topology = hydrogen_to_nearest_atom(bond_topology, dist_matrix)
    if seed_topology is None:
        return matches

    hydrogen = dataset_pb2.BondTopology.AtomType.ATOM_H
    heavy_atoms = [
        index for index, atom_type in enumerate(bond_topology.atoms)
        if atom_type != hydrogen
    ]

    # Maps an atom-index pair to a length-4 array holding the score of each
    # bond type for that pair.
    pair_scores: Dict[Tuple[int, int], np.ndarray] = {}
    for first, second in itertools.combinations(heavy_atoms, 2):
        separation = dist_matrix[first, second]
        if separation > THRESHOLD:
            continue
        try:
            type_probabilities = bond_lengths.probability_of_bond_types(
                bond_topology.atoms[first], bond_topology.atoms[second],
                separation)
        except KeyError:  # No data for this pair of atom types.
            continue
        if type_probabilities:
            # The bond type is used directly as the array index, which relies
            # on BOND_SINGLE == 1 etc.
            per_type = np.zeros(4, np.float32)
            for bond_type, probability in type_probabilities.items():
                per_type[bond_type] = probability
            pair_scores[(first, second)] = per_type

    if not pair_scores:  # Seems unlikely.
        return matches

    # Reference values from the input topology: its SMILES (to recognise when
    # the starting topology has been recovered) and its ring atom count.
    reference_mol = smu_utils_lib.bond_topology_to_molecule(bond_topology)
    starting_smiles = smu_utils_lib.compute_smiles_for_molecule(
        reference_mol, include_hs=True)
    initial_ring_atom_count = utilities.ring_atom_count_mol(reference_mol)

    # SMILES (with hydrogens) already seen, so duplicates are reported once.
    seen_smiles: Set[str] = set()

    mol = smu_molecule.SmuMolecule(seed_topology, pair_scores,
                                   matching_parameters)

    for choice in itertools.product(*mol.generate_search_state()):
        candidate = mol.place_bonds(list(choice), matching_parameters)
        if not candidate:
            continue

        candidate_mol = smu_utils_lib.bond_topology_to_molecule(candidate)
        # When not-bonded is among the options, skip candidates that fall
        # apart into more than one fragment.
        if (matching_parameters.consider_not_bonded
                and len(Chem.GetMolFrags(candidate_mol)) > 1):
            continue

        full_smiles = smu_utils_lib.compute_smiles_for_molecule(
            candidate_mol, include_hs=True)
        if full_smiles in seen_smiles:
            continue
        seen_smiles.add(full_smiles)

        if matching_parameters.ring_atom_count_cannot_decrease:
            ring_atoms = utilities.ring_atom_count_mol(candidate_mol)
            if ring_atoms < initial_ring_atom_count:
                continue
            candidate.ring_atom_count = ring_atoms

        candidate.bond_topology_id = bond_topology.bond_topology_id
        utilities.canonical_bond_topology(candidate)

        if full_smiles == starting_smiles:
            candidate.is_starting_topology = True

        # The stored SMILES omits hydrogens unless the parameters ask for them.
        if matching_parameters.smiles_with_h:
            reported_smiles = full_smiles
        else:
            reported_smiles = smu_utils_lib.compute_smiles_for_molecule(
                candidate_mol, include_hs=False)

        candidate.geometry_score = geometry_score(candidate, dist_matrix,
                                                  bond_lengths)
        candidate.smiles = reported_smiles
        matches.bond_topology.append(candidate)

    if len(matches.bond_topology) > 1:
        matches.bond_topology.sort(key=lambda topology: topology.score,
                                   reverse=True)

    # Replace each raw score by the log of its share of the total.
    total_score = np.sum(
        [topology.score for topology in matches.bond_topology])
    for topology in matches.bond_topology:
        topology.topology_score = np.log(topology.score / total_score)
        topology.ClearField("score")

    return matches