Пример #1
0
def mark_reactants(source_mol: Mol, target_mol: Mol):
    """
    Flag the atoms of ``source_mol`` whose map number also occurs in
    ``target_mol``.

    Every matching source atom gets the boolean property ``'in_target'`` set
    to ``True``. Atoms without a match are left untouched (the property is
    not written as ``False`` on them). ``source_mol`` is modified in place
    and nothing is returned.
    """
    # Iteration order is irrelevant when building a set, so the atoms are
    # consumed directly instead of going through ``reversed``.
    target_maps = {atom.GetAtomMapNum() for atom in target_mol.GetAtoms()}

    for atom in source_mol.GetAtoms():
        if atom.GetAtomMapNum() in target_maps:
            atom.SetBoolProp('in_target', True)
Пример #2
0
def fix_incomplete_mappings(sub_mol: Mol, prod_mol: Mol) -> Tuple[Mol, Mol]:
    """
    Give every unmapped atom of both molecules a fresh, unique map number.

    The highest map number present in either molecule is determined first;
    each atom whose map number is ``None`` or smaller than 1 then receives
    the next unused number. Atoms of ``sub_mol`` are numbered before those
    of ``prod_mol``.

    Both molecules are modified in place and the same two objects are
    returned as ``(sub_mol, prod_mol)``.
    """
    mols = (sub_mol, prod_mol)

    # ``default=0`` keeps this working when neither molecule has any atom;
    # a bare ``max`` over an empty iterable would raise ValueError.
    max_map = max(
        (a.GetAtomMapNum() for mol in mols for a in mol.GetAtoms()),
        default=0,
    )

    for mol in mols:
        for a in mol.GetAtoms():
            map_num = a.GetAtomMapNum()
            if map_num is None or map_num < 1:
                max_map += 1
                a.SetAtomMapNum(max_map)
    return sub_mol, prod_mol
Пример #3
0
def update_feat_values(mol: Mol, atom_props: dict, bond_props: dict):
    """
    Collect the feature values observed in ``mol`` into the given dicts.

    For every atom and every key of ``atom_props`` the value returned by
    ``try_get_atom_feature(atom, key)`` is added to ``atom_props[key]``;
    bonds are handled the same way with ``bond_props`` and
    ``try_get_bond_feature``. The dict values must support ``.add``
    (presumably sets - confirm against the caller). Both dicts are updated
    in place; nothing is returned.
    """
    for atom in mol.GetAtoms():
        for key, seen_values in atom_props.items():
            seen_values.add(try_get_atom_feature(atom, key))

    for bond in mol.GetBonds():
        for key, seen_values in bond_props.items():
            seen_values.add(try_get_bond_feature(bond, key))
Пример #4
0
def fix_explicit_hs(mol: Mol) -> Mol:
    """
    Normalise hydrogen handling of ``mol`` and return a sanitized molecule.

    The "no implicit" flag is cleared on every atom of the input (this
    mutates ``mol``). The molecule is then passed through
    ``Chem.AddHs(..., explicitOnly=True)`` followed by ``Chem.RemoveHs`` and
    finally sanitized with ``Chem.SanitizeMol``. The molecule produced by
    that round trip is returned.
    """
    for atom in mol.GetAtoms():
        atom.SetNoImplicit(False)

    with_hs = Chem.AddHs(mol, explicitOnly=True)
    cleaned = Chem.RemoveHs(with_hs)
    Chem.SanitizeMol(cleaned)

    return cleaned
Пример #5
0
def find_added_benzene_rings(source_mol: Mol,
                             target_mol: Mol) -> List[List[int]]:
    """
    Find benzene rings that were added in the process of reaction generation

    A ring of ``target_mol`` (as reported by ``find_rings``, i.e. a list of
    atom map numbers) is returned when none of its map numbers occurs in
    ``source_mol`` and ``is_benzene_ring`` accepts the corresponding target
    atoms. Rings are returned in the order ``find_rings`` yields them.
    """
    rings_in_target = find_rings(target_mol)

    target_atom_by_map = {
        atom.GetAtomMapNum(): atom for atom in target_mol.GetAtoms()
    }
    source_maps = {atom.GetAtomMapNum() for atom in source_mol.GetAtoms()}

    # The disjointness test comes first so that is_benzene_ring is only
    # evaluated for rings that are entirely absent from the source.
    return [
        ring for ring in rings_in_target
        if source_maps.isdisjoint(ring)
        and is_benzene_ring([target_atom_by_map[m] for m in ring])
    ]
Пример #6
0
def filter_reactants(sub_mols: List[Mol], prod_mol: Mol) -> Mol:
    """
    Keep only the molecules that share an atom map number with the product.

    A molecule from ``sub_mols`` is kept (once, in its original position)
    as soon as any of its atoms carries a map number that also appears in
    ``prod_mol``. The kept molecules are written to SMILES, joined with
    ``'.'`` and parsed back into a single molecule, which is returned.
    """
    product_maps = {atom.GetAtomMapNum() for atom in prod_mol.GetAtoms()}

    kept = [
        mol for mol in sub_mols
        if any(atom.GetAtomMapNum() in product_maps
               for atom in mol.GetAtoms())
    ]

    joined_smiles = '.'.join(Chem.MolToSmiles(mol) for mol in kept)
    return Chem.MolFromSmiles(joined_smiles)
Пример #7
0
def add_map_numbers(mol: Mol) -> Mol:
    """
    Return a copy of ``mol`` with randomly assigned atom map numbers.

    The molecule is first rebuilt from its own SMILES so that the atom
    order is canonical. The numbers ``1..N`` (N = number of atoms) are then
    shuffled with ``np.random.shuffle`` (global NumPy random state) and
    assigned to the atoms in order. The rebuilt molecule is returned; the
    input object itself is not given new map numbers.
    """
    # SMILES round trip -> canonical atom order
    canonical = Chem.MolFromSmiles(Chem.MolToSmiles(mol))

    shuffled_maps = np.arange(canonical.GetNumAtoms()) + 1
    np.random.shuffle(shuffled_maps)

    for atom, map_num in zip(canonical.GetAtoms(), shuffled_maps):
        atom.SetAtomMapNum(int(map_num))
    return canonical
Пример #8
0
def add_benzene_ring(mol: Mol, start_atom_ind: int, ring_atom_maps: List[int]):
    """
    Attach an aromatic six-carbon style ring to ``mol`` at a given atom.

    ``ring_atom_maps`` lists the map numbers of the ring atoms in ring
    order. The atom at ``start_atom_ind`` is marked aromatic and edited.
    For each listed map number the matching existing atom is reused; when
    no atom with that map number existed on entry, a new aromatic carbon is
    added that inherits ``'in_reactant'`` (default ``False``) and
    ``'mol_id'`` (default ``1``) from the start atom. Consecutive ring
    atoms, and finally the first and last one, are connected with aromatic
    bonds where no bond exists yet. Every ring bond gets ``'is_edited'`` set
    to ``True``.

    ``mol`` must support ``AddAtom``/``AddBond`` (i.e. be editable); it is
    modified in place and returned.
    """
    # Snapshot of map number -> atom index taken before any atom is added.
    index_by_map = {
        atom.GetAtomMapNum(): idx for idx, atom in enumerate(mol.GetAtoms())
    }

    anchor = mol.GetAtomWithIdx(start_atom_ind)
    anchor.SetBoolProp('is_edited', True)
    anchor.SetIsAromatic(True)
    anchor_map = anchor.GetAtomMapNum()

    in_reactant = (anchor.GetBoolProp('in_reactant')
                   if anchor.HasProp('in_reactant') else False)
    mol_id = anchor.GetIntProp('mol_id') if anchor.HasProp('mol_id') else 1

    def mark_ring_bond(begin_idx: int, end_idx: int) -> None:
        # Reuse the bond if present, otherwise create an aromatic one.
        ring_bond = mol.GetBondBetweenAtoms(begin_idx, end_idx)
        if ring_bond is None:
            # The value returned by AddBond minus one is used as the index
            # of the bond that was just created.
            new_bond_idx = mol.AddBond(
                begin_idx, end_idx,
                order=Chem.rdchem.BondType.AROMATIC) - 1
            ring_bond = mol.GetBondWithIdx(new_bond_idx)
        ring_bond.SetBoolProp('is_edited', True)

    ring_indices = []
    for atom_map in ring_atom_maps:
        if atom_map == anchor_map:
            ring_indices.append(start_atom_ind)
        elif atom_map in index_by_map:
            ring_indices.append(index_by_map[atom_map])
        else:
            # The new atom is appended, so its index is the current count.
            next_idx = mol.GetNumAtoms()
            carbon = Chem.Atom(6)  # benzene has only carbon atoms
            carbon.SetAtomMapNum(atom_map)
            carbon.SetIsAromatic(True)
            carbon.SetBoolProp('is_edited', True)
            carbon.SetBoolProp('in_reactant', in_reactant)
            carbon.SetIntProp('mol_id', mol_id)
            mol.AddAtom(carbon)
            ring_indices.append(next_idx)

    # Bonds between neighbours along the ring ...
    for begin_idx, end_idx in zip(ring_indices, ring_indices[1:]):
        mark_ring_bond(begin_idx, end_idx)

    # ... and the closing bond between the first and the last ring atom.
    mark_ring_bond(ring_indices[0], ring_indices[-1])

    return mol
Пример #9
0
def find_rings(mol: Mol) -> List[List[int]]:
    """
    Return the rings of ``mol`` expressed in atom map numbers.

    The rings come from ``mol.GetRingInfo().AtomRings()`` (tuples of atom
    indices); each index is replaced by the map number of the atom at that
    position in ``mol.GetAtoms()``. One list per ring is returned, in the
    order reported by the ring info.
    """
    atom_rings = mol.GetRingInfo().AtomRings()

    # Position in this list == atom index used by the ring tuples.
    map_by_index = [atom.GetAtomMapNum() for atom in mol.GetAtoms()]

    return [[map_by_index[idx] for idx in ring] for ring in atom_rings]
Пример #10
0
def rdmol_to_data(mol: Mol):
    """
    Convert a molecule with exactly one conformer into a ``Data`` graph.

    The returned object carries:

    * ``node_type`` - long tensor of atomic numbers, one per atom
    * ``pos`` - float tensor built from the positions of conformer 0
    * ``edge_index`` - long tensor ``[2, E]``; every bond is stored in both
      directions and the edges are sorted by ``row * num_atoms + col``
    * ``edge_type`` - ``BOND_TYPES[bond.GetBondType()]`` for every edge, in
      the same order as ``edge_index``
    * ``rdmol`` - a deep copy of ``mol``
    * ``smiles`` - ``Chem.MolToSmiles(mol)``
    * ``nx`` - undirected graph from ``to_networkx(data, to_undirected=True)``

    An ``AssertionError`` is raised when the molecule does not have exactly
    one conformer.
    """
    assert mol.GetNumConformers() == 1
    num_atoms = mol.GetNumAtoms()

    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float)
    z = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()],
                     dtype=torch.long)

    # Each bond contributes two directed edges (start->end and end->start).
    row, col, bond_types = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        bond_types += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(bond_types)

    # Sort edges by (source, destination) so the layout is deterministic.
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    # Computed before the deep copy below, as in the original statement order.
    smiles = Chem.MolToSmiles(mol)

    data = Data(node_type=z,
                pos=pos,
                edge_index=edge_index,
                edge_type=edge_type,
                rdmol=copy.deepcopy(mol),
                smiles=smiles)
    data.nx = to_networkx(data, to_undirected=True)

    return data
Пример #11
0
    def __init__(self,
                 source_mol: RWMol,
                 target_mol: Mol,
                 action_vocab: dict,
                 forward: bool = False,
                 action_order: str = 'dfs'):
        """
        Set up the state needed to step from ``source_mol`` to ``target_mol``.

        :param source_mol: editable starting molecule; its atoms are marked
            via ``mark_reactants`` during construction
        :param target_mol: molecule the source is compared against
        :param action_vocab: dict that must contain the key ``'prop2oh'``,
            whose value in turn provides ``'atom'`` and ``'bond'`` entries
        :param forward: stored as-is on the instance
        :param action_order: string inspected for the substrings
            ``'random'``, ``'randat'`` and ``'bfs'`` to configure the
            randomisation flags and the initial atom stack
        """
        self.source_mol = source_mol
        self.target_mol = target_mol

        fully_random = action_order == 'random'
        self.randomize_action_types = 'random' in action_order
        self.randomize_map_atom_order = fully_random or 'randat' in action_order
        self.randomize_next_atom = fully_random

        self.action_order = action_order

        # With a 'bfs' ordering the stack starts with all target map numbers
        # in ascending order; otherwise it starts empty.
        if 'bfs' in action_order:
            self.atoms_stack = sorted(
                atom.GetAtomMapNum() for atom in target_mol.GetAtoms())
        else:
            self.atoms_stack = []

        # Must run before the graph below is built: it writes properties
        # onto the atoms of source_mol.
        mark_reactants(source_mol, target_mol)

        self.edited_atoms = set()
        self.forward = forward
        self.action_vocab = action_vocab
        self.prop_dict = action_vocab['prop2oh']

        benzene_rings = find_added_benzene_rings(source_mol=source_mol,
                                                 target_mol=target_mol)
        self.added_rings = {'benzene': benzene_rings}

        self.current_step = 0
        self.current_mol_graph = get_graph(
            self.source_mol,
            ravel=False,
            to_array=True,
            atom_prop2oh=self.prop_dict['atom'],
            bond_prop2oh=self.prop_dict['bond'])
def build_atom_features_matrix(mol: Mol) -> np.ndarray:
    """
    Stack the per-atom features of ``mol`` into one array.

    ``get_atom_features`` is evaluated for every atom in ``mol.GetAtoms()``
    order and the results are passed to ``np.array``.
    """
    per_atom_features = []
    for atom in mol.GetAtoms():
        per_atom_features.append(get_atom_features(atom))
    return np.array(per_atom_features)
Пример #13
0
def renumber_atoms_for_mapping(mol: Mol) -> Mol:
    """
    Reorder the atoms of ``mol`` by ascending atom map number.

    The new order is the ``np.argsort`` of the atoms' map numbers, converted
    to plain Python ints and handed to ``RenumberAtoms``, whose result is
    returned.
    """
    map_nums = [atom.GetAtomMapNum() for atom in mol.GetAtoms()]
    ordering = np.argsort(map_nums).tolist()
    return RenumberAtoms(mol, ordering)
Пример #14
0
def get_atom_ind(mol: Mol, atom_map: int) -> int:
    """
    Return the position of the first atom whose map number is ``atom_map``.

    The position is the atom's index in ``mol.GetAtoms()`` iteration order.

    :raises ValueError: if no atom carries the requested map number
    """
    matches = (idx for idx, atom in enumerate(mol.GetAtoms())
               if atom.GetAtomMapNum() == atom_map)
    found = next(matches, None)
    if found is None:
        raise ValueError(f'No atom with map number: {atom_map}')
    return found
Пример #15
0
def display_numbered(mol: Mol):
    """
    Display a copy of ``mol`` in which every atom is labelled with its index.

    The input is deep-copied first, so its own map numbers stay untouched.
    On the copy each atom's map number is set to ``atom.GetIdx()`` and the
    copy is passed to ``display``.

    NOTE(review): the atom with index 0 receives map number 0, which is
    presumably rendered as "unmapped" - confirm whether that is intended.
    """
    numbered = deepcopy(mol)
    for atom in numbered.GetAtoms():
        index = atom.GetIdx()
        atom.SetAtomMapNum(index)
    display(numbered)