Esempio n. 1
0
    def get_USRlike_atoms(self):
        """Return four reference points similar to those used in USR.

        The points are, in order:

        - ctd: centroid of the conformer returned by
          ``self.mol.GetConformer()``
        - cst: position of the atom closest to ctd
        - fct: position of the atom farthest from cst (USR usually takes the
          atom farthest from ctd; cst is used instead so the interatomic
          distance matrix can be reused rather than computing more distances)
        - ftf: position of the atom farthest from fct

        Returns:
            tuple: ``(ctd, cst, fct, ftf)`` as rdkit ``Point3D`` objects.
        """
        matrix = rdmolops.Get3DDistanceMatrix(self.mol)
        conf = self.mol.GetConformer()
        coords = conf.GetPositions()

        # centroid
        ctd = rdMolTransforms.ComputeCentroid(conf)

        # closest to centroid: start from +inf (not a fixed magic distance)
        # so the first atom always becomes the initial candidate, however
        # far from the centroid the atoms are
        min_dist = float("inf")
        for atom in self.mol.GetAtoms():
            atom_idx = atom.GetIdx()
            point = rdGeometry.Point3D(*coords[atom_idx])
            dist = ctd.Distance(point)
            if dist < min_dist:
                min_dist = dist
                cst = point
                cst_idx = atom_idx

        # farthest from cst
        fct_idx = argmax(matrix[cst_idx])
        fct = rdGeometry.Point3D(*coords[fct_idx])

        # farthest from fct
        ftf_idx = argmax(matrix[fct_idx])
        ftf = rdGeometry.Point3D(*coords[ftf_idx])

        return ctd, cst, fct, ftf
def construct_distance_matrix(mol, out_size=-1, contain_Hs=False):
    """Build the 3D inter-atomic distance matrix of an embedded molecule.

    Args:
        mol (Chem.Mol): molecule to embed with ``AllChem.EmbedMolecule``.
        out_size (int): side length of the returned square matrix. A
            negative value means "number of atoms in ``mol``"; a larger
            value zero-pads the matrix up to that size.
        contain_Hs (bool): if True, ``mol`` is embedded as given. Otherwise
            hydrogens are added before embedding and removed again before
            the distances are measured.

    Returns (numpy.ndarray): 2 dimensional float32 array which represents
        distance between atoms

    Raises:
        MolFeatureExtractionError: if ``mol`` is None, if ``out_size`` is
            non-negative but smaller than the number of atoms, or if
            ``Get3DDistanceMatrix`` raises ``ValueError``.

    """
    if mol is None:
        raise MolFeatureExtractionError('mol is None')

    num_atoms = mol.GetNumAtoms()
    if 0 <= out_size < num_atoms:
        raise MolFeatureExtractionError('out_size {} is smaller than number '
                                        'of atoms in mol {}'.format(
                                            out_size, num_atoms))
    size = num_atoms if out_size < 0 else out_size

    if contain_Hs:
        embedded = mol
        conf_id = AllChem.EmbedMolecule(embedded)
    else:
        with_hs = AllChem.AddHs(mol)
        conf_id = AllChem.EmbedMolecule(with_hs)
        embedded = AllChem.RemoveHs(with_hs)

    try:
        dist_matrix = rdmolops.Get3DDistanceMatrix(embedded, confId=conf_id)
    except ValueError as e:
        logger = getLogger(__name__)
        logger.info('construct_distance_matrix failed, type: {}, {}'.format(
            type(e).__name__, e.args))
        logger.debug(traceback.format_exc())
        raise MolFeatureExtractionError

    # No padding requested: just convert to float32.
    if size <= num_atoms:
        return dist_matrix.astype(numpy.float32)

    # Copy the distances into the top-left corner of a zero matrix.
    padded = numpy.zeros((size, size), dtype=numpy.float32)
    n_rows, n_cols = dist_matrix.shape
    padded[:n_rows, :n_cols] = dist_matrix
    return padded
def construct_distance_matrix(mol, out_size=-1):
    """Construct the 3D distance matrix of ``mol`` after embedding it.

    ``mol`` is embedded with ``AllChem.EmbedMolecule`` and the distances are
    measured in the conformer id that call returns.

    Args:
        mol (Chem.Mol): molecule to embed.
        out_size (int): side length of the returned square matrix. A
            negative value means "number of atoms in ``mol``"; a larger
            value zero-pads the matrix up to that size.

    Returns:
        numpy.ndarray: 2 dimensional float32 array of distances between
        atoms.

    Raises:
        MolFeatureExtractionError: if ``mol`` is None, if ``out_size`` is
            non-negative but smaller than the number of atoms, or if
            ``Get3DDistanceMatrix`` raises ``ValueError``.

    """
    if mol is None:
        raise MolFeatureExtractionError('mol is None')

    atom_count = mol.GetNumAtoms()
    too_small = 0 <= out_size < atom_count
    if too_small:
        raise MolFeatureExtractionError('out_size {} is smaller than number '
                                        'of atoms in mol {}'.format(
                                            out_size, atom_count))
    target = out_size if out_size >= 0 else atom_count

    conformer_id = AllChem.EmbedMolecule(mol)
    try:
        raw = rdmolops.Get3DDistanceMatrix(mol, confId=conformer_id)
    except ValueError as e:
        logger = getLogger(__name__)
        logger.info('construct_distance_matrix failed, type: {}, {}'.format(
            type(e).__name__, e.args))
        logger.debug(traceback.format_exc())
        raise MolFeatureExtractionError

    if target > atom_count:
        # Zero-pad: distances occupy the top-left block of the result.
        result = numpy.zeros((target, target), dtype=numpy.float32)
        height, width = raw.shape
        result[:height, :width] = raw
        return result
    return raw.astype(numpy.float32)
def construct_pair_feature(mol, use_all_feature):
    """construct pair feature

    ``mol`` is embedded with ``AllChem.EmbedMolecule`` and one feature
    vector is produced per bond.

    Args:
        mol (Mol): mol instance
        use_all_feature (bool):
            If True, all pair features are extracted (bond, ring, graph
            distance and expanded 3D distance, stacked in that order).
            If False, only the expanded 3D distance feature is extracted.
            You can confirm the detail in the paper.

    Returns:
        features (numpy.ndarray): The shape is (num_edges, num_edge_features)
        bond_idx (numpy.ndarray): The shape is (2, num_edges)
            bond_idx[0] represents the list of StartNodeIdx and bond_idx[1]
            represents the list of EndNodeIdx.

    Raises:
        MolFeatureExtractionError: if ``Get3DDistanceMatrix`` raises
            ``ValueError`` for the embedded molecule.
    """
    converter = GaussianDistance()

    # prepare the data for extracting the pair feature
    bonds = mol.GetBonds()
    graph_distance_matrix = Chem.GetDistanceMatrix(mol)
    is_in_ring = get_is_in_ring(mol)
    confid = AllChem.EmbedMolecule(mol)
    try:
        coordinate_matrix = rdmolops.Get3DDistanceMatrix(
            mol, confId=confid)
    except ValueError as e:
        logger = getLogger(__name__)
        # The message previously named construct_distance_matrix, which
        # pointed log readers at the wrong function.
        logger.info('construct_pair_feature failed, type: {}, {}'
                    .format(type(e).__name__, e.args))
        logger.debug(traceback.format_exc())
        raise MolFeatureExtractionError

    feature = []
    bond_idx = []
    for bond in bonds:
        start_node = bond.GetBeginAtomIdx()
        end_node = bond.GetEndAtomIdx()

        # create pair feature
        # NOTE(review): the next three values are only consumed when
        # use_all_feature is True; they are still computed unconditionally
        # to preserve any validation the helpers perform - confirm whether
        # they can be skipped otherwise.
        distance_feature = numpy.array(
            graph_distance_matrix[start_node][end_node], dtype=numpy.float32)
        bond_feature = construct_bond_vec(mol, start_node, end_node)
        ring_feature = construct_ring_feature_vec(
            is_in_ring, start_node, end_node)

        # Needed by both modes, so computed once here instead of in each
        # branch.
        expanded_distance_feature = construct_expanded_distance_vec(
            coordinate_matrix, converter, start_node, end_node)

        bond_idx.append((start_node, end_node))
        if use_all_feature:
            feature.append(numpy.hstack((bond_feature, ring_feature,
                                         distance_feature,
                                         expanded_distance_feature)))
        else:
            feature.append(expanded_distance_feature)

    bond_idx = numpy.array(bond_idx).T
    feature = numpy.array(feature)
    return feature, bond_idx
Esempio n. 5
0
def perform_lrp(model, hyper, trial=0, sample=None, epsilon=0.1, gamma=0.1):
    """Run layer-wise relevance propagation and save per-atom heatmaps.

    The best model of the given trial is loaded from
    ``../result/<model>/<hyper>/trial_<trial>/``. Relevance is propagated
    from the output layer back to the atom feature input, and for every test
    complex whose absolute prediction error is at most 0.2 a 2D drawing of
    the ligand with each nearby pocket fragment is written to
    ``../analysis/<model>/<hyper>/heatmap``.

    Args:
        model: model name; used to build the result and analysis paths.
        hyper: hyper-parameter set name; used to build the same paths.
        trial (int): trial number (two-digit formatted in the path).
        sample: optional index applied to the stored test indices to analyse
            only a subset; ``None`` analyses the whole test split.
        epsilon (float): epsilon passed to ``lrp_dense`` for every layer
            except the output layer, which uses 0.
        gamma (float): gamma passed to ``lrp_gcn_gamma``.

    Returns:
        None. The results are the PNG files written to the heatmap folder.
    """
    # Enable memory growth on the first GPU when there is one; indexing the
    # device list unconditionally raised IndexError on CPU-only hosts.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)

    # Make folder (missing parent folders are created as well)
    fig_path = "../analysis/{}/{}/heatmap".format(model, hyper)
    os.makedirs(fig_path, exist_ok=True)

    # Load results
    base_path = "../result/{}/{}/".format(model, hyper)
    path = base_path + 'trial_{:02d}/'.format(trial)

    # Load hyper
    # NOTE(review): the parsed row replaces the ``hyper`` argument but is not
    # read again below; kept so that a missing hyper.csv still fails early.
    with open(path + 'hyper.csv', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            hyper = dict(row)

    # Load model
    custom_objects = {
        'NodeEmbedding': NodeEmbedding,
        'GraphConvolution': GraphConvolution,
        'Normalize': Normalize,
        'GlobalPooling': GlobalPooling
    }
    model = load_model(path + 'best_model.h5', custom_objects=custom_objects)
    print([layer.name for layer in model.layers])

    # Load data; the archive is read once and closed even on failure
    with np.load(path + 'data_split.npz') as data:
        dataset = Dataset('refined', 5)
        test_idx = data['test'] if sample is None else data['test'][sample]
        # Private copy for the drawing loop, independent of whatever
        # split_by_idx does with the array it receives.
        test_set = np.array(test_idx)
        dataset.split_by_idx(32, data['train'], data['valid'], test_idx)

    # Predict: capture the output of every layer the propagation reads
    true_y = dataset.test_y
    outputs = {}
    for layer_name in [
            'node_embedding', 'node_embedding_1', 'normalize', 'normalize_1',
            'activation', 'add', 'activation_1', 'add_1', 'global_pooling',
            'activation_2', 'activation_3', 'activation_4',
            'atom_feature_input'
    ]:
        sub_model = tf.keras.models.Model(
            inputs=model.input, outputs=model.get_layer(layer_name).output)
        outputs[layer_name] = sub_model.predict(dataset.test,
                                                steps=dataset.test_step,
                                                verbose=0)[:len(true_y)]

    def dense_params(layer_name):
        # First two weight arrays of a layer, from one get_weights() call.
        weights = model.get_layer(layer_name).get_weights()
        return weights[0], weights[1]

    # Output layer: LRP-0
    relevance = lrp_dense(outputs['activation_3'],
                          outputs['activation_4'],
                          *dense_params('dense_2'),
                          epsilon=0)

    # Dense layer: LRP-e
    relevance = lrp_dense(outputs['activation_2'],
                          relevance,
                          *dense_params('dense_1'),
                          epsilon=epsilon)

    # Dense layer: LRP-e
    relevance = lrp_dense(outputs['global_pooling'],
                          relevance,
                          *dense_params('dense'),
                          epsilon=epsilon)

    # Pooling layer
    relevance = lrp_pooling(outputs['activation_1'], relevance)

    # Add layer
    relevance_1, relevance_2 = lrp_add(
        [outputs['add'], outputs['activation_1']], relevance)

    # GCN layer: LRP-g
    relevance = lrp_gcn_gamma(
        outputs['add'],
        relevance_2,
        outputs['normalize_1'],
        model.get_layer('graph_convolution_1').get_weights()[0],
        gamma=gamma) + relevance_1

    # The embedding layers are captured above under the names
    # 'node_embedding' / 'node_embedding_1'. The former 'graph_embedding*'
    # keys were never stored in ``outputs`` and raised KeyError.
    # NOTE(review): assumes the node_embedding layers are the ones intended.

    # Add layer
    relevance_1, relevance_2 = lrp_add(
        [outputs['node_embedding_1'], outputs['activation']], relevance)

    # GCN layer: LRP-g
    relevance = lrp_gcn_gamma(
        outputs['node_embedding_1'],
        relevance_2,
        outputs['normalize'],
        model.get_layer('graph_convolution').get_weights()[0],
        gamma=gamma) + relevance_1

    # Embedding layer : LRP-e
    relevance = lrp_dense(outputs['node_embedding'],
                          relevance,
                          *dense_params('node_embedding_1'),
                          epsilon=epsilon)

    # Embedding layer : LRP-e
    relevance = lrp_dense(outputs['atom_feature_input'],
                          relevance,
                          *dense_params('node_embedding'),
                          epsilon=epsilon)

    # Sum over the last axis, then divide by the true target value
    relevance = tf.math.reduce_sum(relevance, axis=-1).numpy()
    relevance = np.divide(relevance, np.expand_dims(true_y, -1))

    # Preset
    DrawingOptions.bondLineWidth = 1.5
    DrawingOptions.elemDict = {}
    DrawingOptions.dotsPerAngstrom = 20
    DrawingOptions.atomLabelFontSize = 4
    DrawingOptions.atomLabelMinFontSize = 4
    DrawingOptions.dblBondOffset = 0.3
    DrawingOptions.wedgeDashedBonds = False

    # Load data
    dataframe = pd.read_pickle('../data/5A.pkl')

    # Draw images for test molecules
    colormap = cm.get_cmap('seismic')
    for idx, test_idx in enumerate(test_set):
        print('Drawing figure for {}/{}'.format(idx, len(test_set)))
        pdb_code = dataframe.iloc[test_idx]['code']
        error = np.absolute(dataframe.iloc[test_idx]['output'] -
                            outputs['activation_4'][idx])[0]
        # Only complexes that were predicted accurately are drawn
        if error > 0.2:
            continue

        ligand_file = '../data/refined-set/{}/{}_ligand.sdf'.format(
            pdb_code, pdb_code)
        pocket_file = '../data/refined-set/{}/{}_pocket.sdf'.format(
            pdb_code, pdb_code)
        for mol_ligand, mol_pocket in zip(Chem.SDMolSupplier(ligand_file),
                                          Chem.SDMolSupplier(pocket_file)):
            n_ligand = mol_ligand.GetNumAtoms()

            # Crop atoms: keep those within 5 A of any ligand atom and
            # scatter their relevance into an array indexed like ``mol``
            mol = Chem.CombineMols(mol_ligand, mol_pocket)
            distance = np.array(rdmolops.Get3DDistanceMatrix(mol))
            cropped_idx = np.argwhere(
                np.min(distance[:, :n_ligand], axis=1) <= 5).flatten()
            unpadded_relevance = np.zeros((mol.GetNumAtoms(), ))
            np.put(unpadded_relevance, cropped_idx, relevance[idx])
            scale = max(max(unpadded_relevance),
                        math.fabs(min(unpadded_relevance))) * 3

            # Separate fragments in Combined Mol
            idxs_frag = rdmolops.GetMolFrags(mol)
            mols_frag = rdmolops.GetMolFrags(mol, asMols=True)
            n_first = mols_frag[0].GetNumAtoms()

            # Draw fragment and interaction
            for i, (mol_frag,
                    idx_frag) in enumerate(zip(mols_frag[1:], idxs_frag[1:])):
                # Ignore water
                if mol_frag.GetNumAtoms() == 1:
                    continue

                # Generate 2D image
                mol_combined = Chem.CombineMols(mols_frag[0], mol_frag)
                AllChem.Compute2DCoords(mol_combined)
                fig = Draw.MolToMPL(mol_combined, coordScale=1)
                axis = fig.axes[0]
                axis.set_axis_off()

                # Draw line between close atoms (5A)
                flag = False
                for j in range(n_ligand):
                    coord_li = mol_combined._atomPs[j]
                    for pos, k in enumerate(idx_frag):
                        if distance[j, k] > 5:
                            continue
                        # Draw connection; ``pos`` is the atom's position
                        # inside the fragment, offset by the first fragment
                        coord_po = mol_combined._atomPs[pos + n_first]
                        x, y = np.array([[coord_li[0], coord_po[0]],
                                         [coord_li[1], coord_po[1]]])
                        axis.add_line(
                            Line2D(x, y, color='b', linewidth=1, alpha=0.3))
                        flag = True

                # Draw heatmap for atoms. ``unpadded_relevance`` is indexed
                # by atom index in ``mol``, so each drawn atom is mapped back
                # to that index first; indexing by the drawing position
                # coloured later pocket fragments with other atoms' values.
                atom_map = tuple(idxs_frag[0]) + tuple(idx_frag)
                for j in range(mol_combined.GetNumAtoms()):
                    atom_relevance = unpadded_relevance[atom_map[j]]
                    highlight = plt.Circle(
                        (mol_combined._atomPs[j][0],
                         mol_combined._atomPs[j][1]),
                        0.035 * math.fabs(atom_relevance / scale) + 0.008,
                        color=colormap(atom_relevance / scale + 0.5),
                        alpha=0.8,
                        zorder=0)
                    axis.add_artist(highlight)

                # Save only drawings that contain at least one connection
                if flag:
                    fig_name = fig_path + '/{}_lrp_{}_{}_{}.png'.format(
                        trial, test_idx, pdb_code, i)
                    fig.savefig(fig_name, bbox_inches='tight')
                plt.close(fig)
Esempio n. 6
0
    def parse_dataset(self):
        """Featurize every ligand/pocket pair into ``self.dataframe``.

        Steps:

        1. Count the atom symbols over all complexes and store them in
           ``self.atom_type`` ordered by decreasing frequency; this order
           defines the one-hot encoding of the ``symbol`` column.
        2. For each complex keep the atoms within ``self.cutoff`` of any
           ligand atom and build per-atom features plus adjacency and
           distance matrices restricted to those atoms.
        3. Pad every per-atom array and matrix to the largest number of kept
           atoms, which is stored in ``self.num_atoms``.

        Reads ``self.x_ligand``, ``self.x_pocket``, ``self.cutoff``,
        ``self.num_atoms`` and ``self.atom_type``; writes ``self.atom_type``,
        ``self.num_atoms`` and ``self.dataframe``.
        """
        def _one_hot(x, allowable_set):
            return [x == s for s in allowable_set]

        # Get total types of atoms
        for ligand, pocket in zip(self.x_ligand, self.x_pocket):
            mol = Chem.CombineMols(ligand, pocket)
            self.num_atoms = max(self.num_atoms, mol.GetNumAtoms())
            for atom in mol.GetAtoms():
                symbol = atom.GetSymbol()
                self.atom_type[symbol] = self.atom_type.get(symbol, 0) + 1
        self.atom_type = dict(
            sorted(self.atom_type.items(),
                   key=lambda item: item[1],
                   reverse=True))
        atom_types = list(self.atom_type)

        columns = [
            'code', 'symbol', 'atomic_num', 'degree', 'hybridization',
            'implicit_valence', 'formal_charge', 'aromaticity', 'ring_size',
            'num_hs', 'acid_base', 'h_donor_acceptor', 'adjacency_intra',
            'adjacency_inter', 'distance', 'output'
        ]
        per_atom_columns = columns[1:12]

        hydrogen_donor = Chem.MolFromSmarts(
            "[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
        hydrogen_acceptor = Chem.MolFromSmarts(
            "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]"
        )
        acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
        basic = Chem.MolFromSmarts(
            "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]"
        )

        hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2
        ]

        def _matched_atoms(mol, pattern):
            # Indices of all atoms taking part in any match of ``pattern``.
            return {
                atom_idx
                for match in mol.GetSubstructMatches(pattern)
                for atom_idx in match
            }

        # One single-row frame per complex; they are joined once at the end
        # because DataFrame.append was removed in pandas 2.0 and copied the
        # whole frame on every call.
        frames = []
        for ligand, pocket in zip(self.x_ligand, self.x_pocket):
            n_ligand = ligand.GetNumAtoms()
            mol = Chem.CombineMols(ligand, pocket)

            # Crop atoms: keep those within the cutoff of any ligand atom
            distance = np.array(rdmolops.Get3DDistanceMatrix(mol))
            idx = np.argwhere(
                np.min(distance[:, :n_ligand], axis=1) <= self.cutoff
            ).flatten().tolist()

            # Get tensors
            Chem.AssignStereochemistry(mol)
            hydrogen_donor_match = _matched_atoms(mol, hydrogen_donor)
            hydrogen_acceptor_match = _matched_atoms(mol, hydrogen_acceptor)
            acidic_match = _matched_atoms(mol, acidic)
            basic_match = _matched_atoms(mol, basic)
            ring = mol.GetRingInfo()

            record = {name: [] for name in per_atom_columns}
            record['code'] = ligand.GetProp('_Name').split('_')[0]
            for atom_idx in idx:
                atom = mol.GetAtomWithIdx(atom_idx)
                record['symbol'].append(_one_hot(atom.GetSymbol(), atom_types))
                record['atomic_num'].append([atom.GetAtomicNum()])
                record['degree'].append(
                    _one_hot(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]))
                record['hybridization'].append(
                    _one_hot(atom.GetHybridization(), hybridizations))
                record['implicit_valence'].append(
                    _one_hot(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]))
                record['formal_charge'].append(
                    _one_hot(atom.GetFormalCharge(), [-3, -2, -1, 0, 1, 2, 3]))
                record['aromaticity'].append([atom.GetIsAromatic()])
                record['ring_size'].append([
                    ring.IsAtomInRingOfSize(atom_idx, ring_size)
                    for ring_size in (3, 4, 5, 6, 7, 8)
                ])
                record['num_hs'].append(
                    _one_hot(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]))
                record['acid_base'].append(
                    [atom_idx in acidic_match, atom_idx in basic_match])
                record['h_donor_acceptor'].append([
                    atom_idx in hydrogen_donor_match,
                    atom_idx in hydrogen_acceptor_match
                ])

            adjacency_intra = np.array(
                rdmolops.GetAdjacencyMatrix(mol))[idx][:, idx]
            # Inter-molecular mask: ones on the ligand/pocket off-diagonal
            # blocks, zeros elsewhere
            adjacency_inter = np.zeros_like(adjacency_intra)
            adjacency_inter[:n_ligand, n_ligand:] = 1.
            adjacency_inter[n_ligand:, :n_ligand] = 1.
            record['adjacency_intra'] = adjacency_intra
            record['adjacency_inter'] = adjacency_inter
            # Reuse the distance matrix computed for the cropping step
            record['distance'] = distance[idx][:, idx]
            record['output'] = float(ligand.GetProp('target_calc'))

            frames.append(
                pd.DataFrame([[record[name] for name in columns]],
                             columns=columns))

        if frames:
            # sort=True asks for sorted columns, as the former
            # DataFrame.append(..., sort=True) call did.
            self.dataframe = pd.concat(frames, ignore_index=True, sort=True)
        else:
            self.dataframe = pd.DataFrame(columns=columns)

        # Pad data
        self.num_atoms = max(
            (len(symbols) for symbols in self.dataframe['symbol']), default=0)

        # Column positions are looked up by name so the padding does not
        # depend on the order of the columns in the frame.
        position = {
            name: self.dataframe.columns.get_loc(name)
            for name in columns
        }
        square_columns = ('adjacency_inter', 'adjacency_intra', 'distance')
        boolean_columns = ('acid_base', 'aromaticity', 'degree',
                           'formal_charge', 'h_donor_acceptor',
                           'hybridization', 'implicit_valence', 'num_hs',
                           'ring_size', 'symbol')

        for i in range(len(self.dataframe)):
            # Number of padding atoms needed for this complex
            delta = self.num_atoms - len(
                self.dataframe.iat[i, position['symbol']])

            # Matrices grow along both axes
            for name in square_columns:
                j = position[name]
                self.dataframe.iat[i, j] = np.pad(self.dataframe.iat[i, j],
                                                  ((0, delta), (0, delta)),
                                                  'constant',
                                                  constant_values=0)

            # Per-atom boolean features gain all-False rows
            for name in boolean_columns:
                j = position[name]
                self.dataframe.iat[i, j] = np.pad(self.dataframe.iat[i, j],
                                                  ((0, delta), (0, 0)),
                                                  'constant',
                                                  constant_values=False)

            # Atomic numbers gain zero rows
            j = position['atomic_num']
            self.dataframe.iat[i, j] = np.pad(self.dataframe.iat[i, j],
                                              ((0, delta), (0, 0)),
                                              'constant',
                                              constant_values=0)
Esempio n. 7
0
def geometric_matrix(mol, conformer=-1):
    """Return the 3D distance matrix of ``mol`` for the given conformer id.

    Thin wrapper around ``rdmolops.Get3DDistanceMatrix``; ``conformer`` is
    forwarded as its ``confId`` argument (default -1).
    """
    distances = rdmolops.Get3DDistanceMatrix(mol, confId=conformer)
    return distances