Exemplo n.º 1
0
def sample_subgraphs(smiles, num_samples=10, frac=0.5, vis_dir=None, rng=None):
    """Randomly sample cluster-aligned atom subsets of a molecule.

    For each sample, ``int(frac * len(clusters))`` distinct clusters returned
    by ``find_clusters`` are drawn without replacement, with probability
    proportional to cluster size. The union of their atom indices is recorded.

    Args:
        smiles: SMILES string of the molecule to sample from.
        num_samples: number of atom subsets to draw.
        frac: fraction of clusters selected per sample (truncated to an int).
        vis_dir: optional directory. When given, each sample is rendered as
            ``subgraph_{n}.png`` (selection highlighted on the full molecule).
            The extracted subgraph SMILES is printed and, when it parses, is
            rendered as ``subgraph_{n}_extracted.png``.
        rng: optional object with a NumPy-style ``choice`` method (e.g.
            ``np.random.default_rng(seed)``) for reproducible sampling.
            Defaults to the global ``np.random`` state, as before.

    Returns:
        List of ``num_samples`` sets of atom indices.
    """
    mol = Chem.MolFromSmiles(smiles)
    clusters, _ = find_clusters(mol)
    cluster_sizes = np.array([len(cls) for cls in clusters], dtype=float)
    p = cluster_sizes / cluster_sizes.sum()  # size-proportional sampling weights
    num_selected = int(frac * len(clusters))
    choose = np.random.choice if rng is None else rng.choice

    selected_atoms_list = []
    for n in range(num_samples):
        selected_clusters = choose(len(clusters), num_selected, p=p, replace=False)
        selected_atoms = set()
        for i in selected_clusters:
            selected_atoms.update(clusters[i])
        selected_atoms_list.append(selected_atoms)

        if vis_dir is not None:
            # The extracted SMILES is only needed for visualisation, so it is
            # computed here rather than for every sample.
            minimum_smiles, _ = extract_subgraph(smiles, selected_atoms)
            Draw.MolToFile(mol, filename=os.path.join(vis_dir, f'subgraph_{n}.png'),
                           highlightAtoms=selected_atoms)
            print(minimum_smiles)
            # extract_subgraph can return None (callers elsewhere check for it),
            # and MolFromSmiles returns None for unparsable input; drawing None
            # would raise, so skip the extracted image in that case.
            extracted = Chem.MolFromSmiles(minimum_smiles) if minimum_smiles is not None else None
            if extracted is not None:
                Draw.MolToFile(extracted, filename=os.path.join(vis_dir, f'subgraph_{n}_extracted.png'))
    return selected_atoms_list
Exemplo n.º 2
0
def mcts_rollout(node, state_map, orig_smiles, clusters, atom_cls, nei_cls,
                 scoring_function):
    """Run one Monte-Carlo tree search simulation starting at ``node``.

    If ``node`` has at most ``MIN_ATOMS`` atoms, its prior ``P`` is returned
    unchanged. A node without children is expanded first. Each child removes,
    from one cluster that qualifies as removable, the atoms that belong to no
    other cluster still inside the node. Children whose SMILES is already in
    ``state_map`` reuse that node, and children with a falsy SMILES are
    dropped. The children's priors are then set from
    ``scoring_function(list_of_child_smiles)``.

    The child maximising ``Q() + U(total_child_visits)`` is descended into
    recursively. The returned value is added to its ``W`` and its ``N`` is
    incremented.

    Returns:
        The value backed up from the leaf reached by this simulation.
    """
    current_atoms = node.atoms
    if len(current_atoms) <= MIN_ATOMS:
        return node.P

    if not node.children:
        # Indices of clusters lying completely inside the current atom set.
        live_clusters = {idx for idx, members in enumerate(clusters)
                         if members <= current_atoms}
        for idx in live_clusters:
            # Atoms of this cluster that no other live cluster contains.
            own_atoms = [atom for atom in clusters[idx]
                         if len(atom_cls[atom] & live_clusters) == 1]
            # presumably nei_cls[idx] lists clusters adjacent to idx — verify with caller
            single_live_neighbor = len(nei_cls[idx] & live_clusters) == 1
            dangling_pair = len(clusters[idx]) == 2 and len(own_atoms) == 1
            if not (single_live_neighbor or dangling_pair):
                continue

            remaining_atoms = current_atoms - set(own_atoms)
            child_smiles, _ = extract_subgraph(orig_smiles, remaining_atoms)
            if child_smiles in state_map:
                child = state_map[child_smiles]  # share the node for an identical state
            else:
                child = MCTSNode(child_smiles, remaining_atoms)
            if child_smiles:
                node.children.append(child)

        state_map[node.smiles] = node
        if not node.children:
            return node.P  # nothing removable: treat as a leaf

        priors = scoring_function([child.smiles for child in node.children])
        for child, prior in zip(node.children, priors):
            child.P = prior

    total_visits = sum(child.N for child in node.children)
    best = max(node.children, key=lambda cand: cand.Q() + cand.U(total_visits))
    value = mcts_rollout(best, state_map, orig_smiles, clusters, atom_cls,
                         nei_cls, scoring_function)
    best.W += value
    best.N += 1
    return value
Exemplo n.º 3
0
def extract_selected_subgraph(smiles, selected_atoms):
    """Extract the subgraph spanned by ``selected_atoms``, completing partial rings.

    Every cluster from ``find_clusters`` with more than two atoms (the original
    code treats these as rings) that has at least two of its atoms selected is
    added in full before extraction.

    Args:
        smiles: SMILES string of the full molecule.
        selected_atoms: iterable of atom indices. It is NOT modified; the
            expanded selection is built on a private copy.

    Returns:
        The first value returned by ``extract_subgraph`` for the expanded atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    clusters, _ = find_clusters(mol)

    chosen = set(selected_atoms)   # membership is tested against the caller's selection
    expanded = set(chosen)         # grows with whole clusters; caller's object untouched
    for cls in clusters:
        if len(cls) > 2 and sum(atom in chosen for atom in cls) >= 2:
            expanded.update(cls)

    minimum_smiles, _ = extract_subgraph(smiles, expanded)
    return minimum_smiles
Exemplo n.º 4
0
def extract_selected_subgraph_for_gcn(smiles, selected_atoms, vis_dir=None):
    """Extract the subgraph formed by every cluster touching a selected atom.

    Each selected atom pulls in all atoms of every cluster it belongs to (per
    ``find_clusters``). The union is passed to ``extract_subgraph``.

    Args:
        smiles: SMILES string of the full molecule.
        selected_atoms: iterable of atom indices. It is NOT modified.
        vis_dir: optional directory. When given, the full molecule is drawn
            with the expanded atom set highlighted, to
            ``atoms_selected{len(expanded)}.png``.

    Returns:
        The first value returned by ``extract_subgraph`` for the expanded atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    clusters, atom_cls = find_clusters(mol)

    seeds = set(selected_atoms)
    expanded = set(seeds)
    for atom in seeds:
        for cls in atom_cls[atom]:
            # Merge atoms directly instead of collecting clusters in a set, so
            # clusters may be unhashable (e.g. Python sets).
            expanded.update(clusters[cls])

    minimum_smiles, _ = extract_subgraph(smiles, expanded)
    if vis_dir is not None:
        png_f = f'atoms_selected{len(expanded)}.png'
        Draw.MolToFile(mol, filename=os.path.join(vis_dir, png_f), highlightAtoms=expanded)

    return minimum_smiles
Exemplo n.º 5
0
def find_minimum_subgraph(smiles, selected_atoms, vis_dir=None):
    """Return the SMILES of a minimal cluster-closed subgraph covering ``selected_atoms``.

    Steps:
      1. Choose clusters (from ``find_clusters``) so that every selected atom
         lies in a chosen cluster (``_select_covering_clusters``).
      2. Build the cluster adjacency graph: two clusters are adjacent when they
         share an atom (``_cluster_neighbors``).
      3. Repeatedly drop unchosen clusters that have at most one neighbour, or
         whose neighbours are pairwise adjacent. Such a removal cannot
         disconnect the remaining clusters (``_prune_removable_clusters``).
      4. Extract the union of the surviving clusters' atoms.

    Args:
        smiles: SMILES string of the full molecule.
        selected_atoms: atom indices that must be covered. It is iterated
            twice, so pass a re-iterable collection (set/list), not a generator.
        vis_dir: optional directory. When given, PNGs of the selection, of the
            minimum atom set, and (if it parses) of the extracted subgraph are
            written there.

    Returns:
        The first value returned by ``extract_subgraph`` for the surviving atoms.

    Raises:
        AssertionError: if a selected atom belongs to no cluster (assert-based,
            so not checked under ``python -O``).
    """
    mol = Chem.MolFromSmiles(smiles)
    clusters, atom_cls = find_clusters(mol)

    selected_clusters = _select_covering_clusters(selected_atoms, atom_cls)
    cluster_neighbor = _cluster_neighbors(clusters, atom_cls)
    leaf_clusters = _prune_removable_clusters(len(clusters), selected_clusters, cluster_neighbor)

    minimum_atoms = set()
    for i, cluster in enumerate(clusters):
        if i not in leaf_clusters:
            minimum_atoms.update(cluster)

    minimum_smiles, _ = extract_subgraph(smiles, minimum_atoms)

    if vis_dir is not None:
        n_selected = len(selected_atoms)
        Draw.MolToFile(mol, filename=os.path.join(vis_dir, f'atoms_selected{n_selected}.png'),
                       highlightAtoms=selected_atoms)
        Draw.MolToFile(mol, filename=os.path.join(vis_dir, f'atoms_minimum{n_selected}.png'),
                       highlightAtoms=minimum_atoms)
        # extract_subgraph can return None (callers elsewhere check for it), and
        # MolFromSmiles returns None for unparsable input; drawing None raises.
        extracted = Chem.MolFromSmiles(minimum_smiles) if minimum_smiles is not None else None
        if extracted is not None:
            Draw.MolToFile(extracted,
                           filename=os.path.join(vis_dir, f'atoms_minimum_extracted{n_selected}.png'))

    return minimum_smiles


def _select_covering_clusters(selected_atoms, atom_cls):
    """Choose cluster indices so that every selected atom lies in a chosen cluster."""
    selected_clusters = set()
    cluster_votes = {}
    # Pass 1: take a cluster when a selected atom belongs only to it, or when
    # two selected multi-cluster atoms both vote for it.
    for atom in selected_atoms:
        owners = atom_cls[atom]
        assert len(owners) > 0
        if len(owners) == 1:
            selected_clusters.add(owners[0])
            continue
        for cls in owners:
            cluster_votes[cls] = cluster_votes.get(cls, 0) + 1
            if cluster_votes[cls] >= 2:
                selected_clusters.add(cls)
    # Pass 2: any atom still uncovered takes its first listed cluster
    # (deterministic, not random). The result depends on iteration order.
    for atom in selected_atoms:
        if not any(cls in selected_clusters for cls in atom_cls[atom]):
            selected_clusters.add(atom_cls[atom][0])
    return selected_clusters


def _cluster_neighbors(clusters, atom_cls):
    """Map each cluster index to the set of other clusters sharing an atom with it."""
    neighbors = {}
    for i, cluster in enumerate(clusters):
        adjacent = set()
        for atom in cluster:
            adjacent.update(atom_cls[atom])
        adjacent.remove(i)  # a cluster is not its own neighbour
        neighbors[i] = adjacent
    return neighbors


def _prune_removable_clusters(num_clusters, selected_clusters, cluster_neighbor):
    """Iteratively remove unselected clusters whose removal keeps the rest connected.

    A cluster is removable when it has at most one neighbour, or when all of
    its neighbours are pairwise adjacent. ``cluster_neighbor`` is updated in
    place as clusters are removed. Returns the set of removed cluster indices.
    """
    removed = set()
    changed = True
    while changed:
        changed = False
        for i in range(num_clusters):
            if i in selected_clusters or i in removed:
                continue
            neighbors = cluster_neighbor[i]
            if len(neighbors) > 1:
                pairs = [(j, k) for j in neighbors for k in neighbors if j < k]
                if not all(j in cluster_neighbor[k] and k in cluster_neighbor[j]
                           for j, k in pairs):
                    continue
            removed.add(i)
            for j in neighbors:
                cluster_neighbor[j].remove(i)
            changed = True
    return removed
import sys
from rdkit import Chem
from multiobj_rationale.fuseprop import enum_subgraph, extract_subgraph

ratio_list = [0.3, 0.4, 0.5, 0.6, 0.7]


def main(lines=None):
    """Print distinct rationale subgraphs for each molecule in a CSV stream.

    The first line is treated as a header and skipped; empty input is allowed.
    For every other line, the first comma-separated field is parsed as SMILES,
    and each atom selection from ``enum_subgraph(mol, ratio_list)`` is passed
    to ``extract_subgraph``. Each distinct non-None subgraph is printed as
    ``"<smiles> <subgraph>"``, in first-seen order. Blank lines are skipped.
    Unparsable SMILES are reported on stderr and skipped.

    Args:
        lines: iterable of text lines; defaults to ``sys.stdin``.
    """
    if lines is None:
        lines = sys.stdin
    lines = iter(lines)
    next(lines, None)  # skip the CSV header without raising on empty input

    for line in lines:
        smiles = line.strip("\r\n ").split(',')[0]
        if not smiles:
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # MolFromSmiles signals parse failure with None
            print(f"skipping invalid SMILES: {smiles!r}", file=sys.stderr)
            continue

        # dict.fromkeys de-duplicates while keeping first-seen order, so output
        # order no longer depends on the string hash seed (as set() did).
        unique = dict.fromkeys(
            subgraph
            for subgraph, _ in (extract_subgraph(smiles, atoms)
                                for atoms in enum_subgraph(mol, ratio_list))
            if subgraph is not None
        )
        for subgraph in unique:
            print(smiles, subgraph)


if __name__ == "__main__":
    main()