def collate_fn(
    batch, k, p,
    ms=mol_spec.get_default_mol_spec()):
    """
    Collate a batch of molecules into padded construction arrays.

    Each molecule is encoded with ``get_array_from_mol`` and the resulting
    step axis is right-padded with -1 so that all molecules in the batch
    share the same number of steps.

    Parameters
    ----------
        batch : list
            The input molecules [mol, mol, mol ...]; each must provide
            ``GetNumBonds()``
        k : int
            Forwarded to ``get_array_from_mol`` as ``num_samples``
        p : float
            Forwarded to ``get_array_from_mol`` as ``p``
        ms : mol_spec.MoleculeSpec

    Returns
    -------
        mol_array : np.ndarray
            shape: [batch_size, k, max_num_steps, 4]
        logp : np.ndarray
            shape: [batch_size, k]
    """
    # one step per bond plus the initialization step
    max_num_steps = max(mol.GetNumBonds() for mol in batch) + 1

    padded_arrays, logp_values = [], []
    for mol in batch:
        # sampled: k x num_steps x 4, sampled_logp: k
        sampled, sampled_logp = get_array_from_mol(mol, k, p, ms)

        # right-pad the step axis with -1 up to the batch-wide maximum
        extra_steps = max_num_steps - sampled.shape[1]
        padded_arrays.append(
            np.pad(sampled,
                   pad_width=((0, 0), (0, extra_steps), (0, 0)),
                   mode='constant',
                   constant_values=-1))
        logp_values.append(sampled_logp)

    return np.stack(padded_arrays, axis=0), np.stack(logp_values, axis=0)
def get_mol_from_array(mol_array,
                       sanitize=True,
                       ms=mol_spec.get_default_mol_spec()):
    """
    Converting molecule array to Chem.Mol objects

    Each row of a sample is one construction step
    ``[atom_type, begin_ids, end_ids, bond_type]`` and is interpreted as:

    - ``end_ids == -1``: terminate (stop reading further steps)
    - ``begin_ids == -1``: initialize (add the first atom, no bond)
    - ``atom_type == -1``: connect (add a bond between existing atoms)
    - otherwise: append (add a new atom and bond it)

    Parameters
    ----------
        mol_array : np.ndarray
            The array representation of molecules
            dtype: int, shape: [num_samples, num_steps, 4]
        sanitize : bool
            Whether to sanitize the output molecule, default to True.
            When False the editable ``Chem.RWMol`` is returned as built.
        ms : mol_spec.MoleculeSpec

    Returns
    -------
        list[Chem.Mol]
            The list of output molecules; an entry is None when building or
            sanitizing that sample raised an exception
    """

    # get shape information
    num_samples, max_num_steps, _ = mol_array.shape

    # initialize the list of output molecules
    mol_list = []

    # loop over molecules
    for mol_id in range(num_samples):
        try:
            mol = Chem.RWMol(Chem.Mol())  # initialize molecule
            for step_id in range(max_num_steps):
                atom_type, begin_ids, end_ids, bond_type = mol_array[
                    mol_id, step_id, :].tolist()
                if end_ids == -1:
                    # if the action is to terminate
                    break
                elif begin_ids == -1:
                    # if the action is to initialize
                    new_atom = ms.index_to_atom(atom_type)
                    mol.AddAtom(new_atom)
                elif atom_type == -1:
                    # if the action is to connect
                    ms.index_to_bond(mol, begin_ids, end_ids, bond_type)
                else:
                    # if the action is to append new atom
                    new_atom = ms.index_to_atom(atom_type)
                    mol.AddAtom(new_atom)
                    ms.index_to_bond(mol, begin_ids, end_ids, bond_type)
            if sanitize:
                mol = mol.GetMol()
                Chem.SanitizeMol(mol)
        except Exception:
            # An invalid sample yields None, but KeyboardInterrupt and
            # SystemExit must still propagate, hence no bare ``except``.
            mol = None
        mol_list.append(mol)

    return mol_list
def graph_to_c_scaffold(graph):
    """
    Convert a graph into the SMILES of a single-bonded scaffold.

    One atom is created per node using atom type index 0 of the default
    molecule spec (presumably carbon, given the function name -- confirm
    against the spec), and each atom is tagged with its position as the
    ``molAtomMapNumber`` property. Every edge becomes a single bond; edge
    endpoints are used directly as atom indices, so nodes are expected to be
    labelled 0..n-1.

    Parameters
    ----------
        graph
            Object exposing ``nodes`` (sized) and ``edges`` (iterable of
            ``(begin_id, end_id)`` pairs)

    Returns
    -------
        The SMILES produced by ``AllChem.MolToSmiles`` for the scaffold
    """
    spec = data_struct.get_default_mol_spec()
    scaffold = Chem.RWMol(Chem.Mol())
    single_bond = Chem.rdchem.BondType.SINGLE

    num_nodes = len(graph.nodes)
    for map_number in range(num_nodes):
        atom = spec.index_to_atom(0)
        atom.SetProp('molAtomMapNumber', str(map_number))
        scaffold.AddAtom(atom)

    for begin_id, end_id in graph.edges:
        scaffold.AddBond(begin_id, end_id, single_bond)

    return AllChem.MolToSmiles(scaffold)
def get_mol_from_graph(symbol_charge_hs, bond_start_end, sanitize=True):
    """
    Build a molecule from atom descriptors and a bond list.

    Parameters
    ----------
        symbol_charge_hs : iterable
            One entry per atom; each entry is looked up in the default
            molecule spec's ``atom_types`` to obtain its atom type index
            (presumably (symbol, charge, num_hs) records -- confirm against
            the spec)
        bond_start_end : iterable
            One entry per bond; ``bond[0]``, ``bond[1]`` and ``bond[2]`` are
            passed to ``index_to_bond`` as begin index, end index and bond
            type
        sanitize : bool
            Whether to sanitize the output molecule, default to True

    Returns
    -------
        Chem.Mol or None
            The sanitized molecule, or None if sanitization fails. With
            ``sanitize=False`` the unsanitized, editable ``Chem.RWMol`` is
            returned as built.

    Raises
    ------
        ValueError
            If an atom entry is not present in the spec's ``atom_types``
    """
    chem = data_struct.get_default_mol_spec()
    mol = Chem.RWMol(Chem.Mol())
    for atom in symbol_charge_hs:
        mol.AddAtom(chem.index_to_atom(chem.atom_types.index(atom)))
    for bond in bond_start_end:
        chem.index_to_bond(mol, bond[0], bond[1], bond[2])

    if not sanitize:
        # Hand back the molecule as built; returning None here would throw
        # away the result whenever the caller opts out of sanitization.
        return mol

    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
    except Exception:
        # A chemically invalid molecule yields None, but KeyboardInterrupt
        # and SystemExit must still propagate, hence no bare ``except``.
        return None
    return mol
def get_array_from_mol(mol,
                       num_samples=1,
                       p=0.9,
                       ms=mol_spec.get_default_mol_spec()):
    """
    Encode a molecule as sampled step-by-step construction arrays.

    For each sampled ordering, the atoms and bonds are reordered and packed
    into rows ``[atom_type_added, start_ids, end_ids, bond_type]``, preceded
    by one initialization row ``[first_atom_type, -1, 0, -1]``.
    ``atom_type_added`` is -1 for steps that only connect existing atoms.

    Parameters
    ----------
        mol : Chem.Mol
            The input molecule
        num_samples : int
            Number of orderings to sample, default to 1
        p : float
            Forwarded to ``sample_ordering``, default to 0.9
        ms : mol_spec.MoleculeSpec

    Returns
    -------
        mol_array : np.ndarray
            shape: [num_samples, num_bonds + 1, 4]
        logp
            Third value returned by ``sample_ordering``
            (size: num_samples; presumably the log-probability of each
            sampled ordering -- confirm against ``sample_ordering``)
    """
    # shape: num_atoms
    atom_types = np.array(
        [ms.get_atom_type(atom) for atom in mol.GetAtoms()],
        dtype=np.int32)
    # shape: num_bonds x 3, rows are [start_ids, end_ids, bond_type]
    bond_info = np.array(
        [[bond.GetBeginAtomIdx(),
          bond.GetEndAtomIdx(),
          ms.get_bond_type(bond)] for bond in mol.GetBonds()],
        dtype=np.int32)
    num_bonds = mol.GetNumBonds()

    # sample the construction orderings
    routes, step_ids, logp = sample_ordering(mol, num_samples, p, ms)

    per_sample = []
    for sample_id in range(num_samples):
        # reorder atoms and bonds according to this sample's route
        ordered_types, ordered_bonds, is_append = reorder(
            atom_types, bond_info,
            routes[sample_id, :], step_ids[sample_id, :])

        # type of the atom added at each step; stays -1 for connect steps
        added_types = np.full((num_bonds,), -1, dtype=np.int32)
        end_atom_types = ordered_types[ordered_bonds[:, 1]]
        added_types[is_append] = end_atom_types[is_append]

        # size: num_bonds x 4
        bond_steps = np.concatenate(
            [added_types[:, np.newaxis], ordered_bonds], axis=-1)
        # the initialization step places the first atom without a bond
        init_step = np.array([[ordered_types[0], -1, 0, -1]], dtype=np.int32)
        # size: (num_bonds + 1) x 4
        per_sample.append(np.concatenate([init_step, bond_steps], axis=0))

    # num_samples x (num_bonds + 1) x 4
    return np.stack(per_sample, axis=0), logp
import multiprocessing as mp
import os
from copy import deepcopy
from os import path
import linecache

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx

from data import data_struct as mol_spec
from data import data_struct

# Default molecule specification; source of the module-level constants below.
ms = mol_spec.get_default_mol_spec()
# Atom symbols as provided by the default molecule specification.
ATOM_SYMBOLS = ms.atom_symbols
# NOTE: disabled legacy loader, kept for reference. It filled ATOM_SYMBOLS
# with the first comma-separated field of every line in
# datasets/atom_types.txt (located next to this file):
# []
# with open(path.join(
#     path.dirname(__file__),
#     'datasets',
#     'atom_types.txt')
# ) as f:
#     for line in f.readlines():
#         line = line.strip().split(',')
#         ATOM_SYMBOLS.append(line[0])

# Bond orders as provided by the default molecule specification.
BOND_ORDERS = ms.bond_orders
__all__ = [
    'mol_gen',
    'to_tensor',