Exemple #1
0
    def _pre_process(self, smile2graph):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Graphs are cached with ``pickle`` at ``self.cache_file_path`` and reloaded
        from there when the file exists.

        Parameters
        ----------
        smile2graph : callable
            Function converting a single SMILES string into a DGLGraph.
        """
        if os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them.
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # cache_file_path at files produced by a trusted process.
            print('Loading previously saved dgl graphs...')
            with open(self.cache_file_path, 'rb') as f:
                self.graphs = pickle.load(f)
        else:
            self.graphs = [smile2graph(s) for s in self.smiles]

            with open(self.cache_file_path, 'wb') as f:
                pickle.dump(self.graphs, f)

        _label_values = self.df[self.task_names].values
        # np.nan_to_num will also turn inf into a very large number
        self.labels = F.zerocopy_from_numpy(
            np.nan_to_num(_label_values).astype(np.float32))
        # The parentheses matter: ``~`` must negate the boolean NaN mask before the
        # cast. Without them it is applied to a float32 array and raises TypeError.
        self.mask = F.zerocopy_from_numpy(
            (~np.isnan(_label_values)).astype(np.float32))
Exemple #2
0
    def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer,
                     load, log_every):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them

        Graphs, labels and masks are cached at ``self.cache_file_path`` via
        ``save_graphs`` and reloaded via ``load_graphs`` when the file exists and
        ``load`` is True.

        Parameters
        ----------
        smiles_to_graph : callable, SMILES -> DGLGraph
            Function for converting a SMILES (str) into a DGLGraph.
        node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for nodes like atoms in a molecule, which can be used to update
            ndata for a DGLGraph.
        edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for edges like bonds in a molecule, which can be used to update
            edata for a DGLGraph.
        load : bool
            Whether to load the previously pre-processed dataset or pre-process from scratch.
            ``load`` should be False when we want to try different graph construction and
            featurization methods and need to preprocess from scratch.
        log_every : int
            Print a message every time ``log_every`` molecules are processed. A value of
            0 or None disables these progress messages.
        """
        if load and os.path.exists(self.cache_file_path):
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
            self.labels = label_dict['labels']
            self.mask = label_dict['mask']
        else:
            print('Processing dgl graphs from scratch...')
            self.graphs = []
            for i, s in enumerate(self.smiles):
                # Guard against log_every == 0, which would raise ZeroDivisionError.
                if log_every and (i + 1) % log_every == 0:
                    print('Processing molecule {:d}/{:d}'.format(
                        i + 1, len(self)))
                self.graphs.append(
                    smiles_to_graph(s,
                                    node_featurizer=node_featurizer,
                                    edge_featurizer=edge_featurizer))
            _label_values = self.df[self.task_names].values
            # np.nan_to_num will also turn inf into a very large number
            self.labels = F.zerocopy_from_numpy(
                np.nan_to_num(_label_values).astype(np.float32))
            self.mask = F.zerocopy_from_numpy(
                (~np.isnan(_label_values)).astype(np.float32))
            save_graphs(self.cache_file_path,
                        self.graphs,
                        labels={
                            'labels': self.labels,
                            'mask': self.mask
                        })
Exemple #3
0
def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
    """Return positions of frontier edges whose parent-graph IDs are excluded.

    Both arguments are either numpy arrays, or dicts (Mappings) of numpy arrays
    keyed by edge type. In the dict case, only edge types present in both
    mappings appear in the result.

    Returns
    -------
    Tensor or dict of Tensor
        Indices into ``frontier_parent_eids`` (per edge type for dict input)
        whose values occur in ``exclude_eids``.
    """
    def _positions(parent_eids, excluded):
        # Indices of the entries of ``parent_eids`` that occur in ``excluded``.
        return np.isin(parent_eids, excluded).nonzero()[0]

    if not isinstance(frontier_parent_eids, Mapping):
        return F.zerocopy_from_numpy(
            _positions(frontier_parent_eids, exclude_eids))

    shared_types = [etype for etype in frontier_parent_eids.keys()
                    if etype in exclude_eids.keys()]
    return {
        etype: F.zerocopy_from_numpy(
            _positions(frontier_parent_eids[etype], exclude_eids[etype]))
        for etype in shared_types
    }
Exemple #4
0
    def __call__(self, mol):
        """Featurize all bonds in a molecule.

        Every bond contributes two identical rows: bond ``i`` occupies rows
        ``2 * i`` and ``2 * i + 1`` of each feature.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            For each function in self.featurizer_funcs with the key ``k``, store the computed
            feature under the key ``k``. Each feature is a tensor of dtype float32 and shape
            (2B, M), where B is the number of bonds in the molecule and M is the feature
            size. The dict is empty when the molecule has no bonds.
        """
        num_bonds = mol.GetNumBonds()
        bond_features = defaultdict(list)

        # Compute features for each bond. Each feature is stored twice,
        # presumably one row per direction of the bond in a bidirected graph.
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            for feat_name, feat_func in self.featurizer_funcs.items():
                feat = feat_func(bond)
                bond_features[feat_name].extend([feat, feat.copy()])

        # Stack the features and convert them to float32 tensors
        return {
            feat_name: F.zerocopy_from_numpy(
                np.stack(feat_list).astype(np.float32))
            for feat_name, feat_list in bond_features.items()
        }
Exemple #5
0
    def __getitem__(self, item):
        """Return the datapoint stored at index ``item``.

        Returns
        -------
        str
            SMILES string of the datapoint
        DGLGraph
            Graph built for the datapoint
        Tensor of dtype float32
            Label values of the datapoint for every task
        Tensor of dtype float32
            Entries of ``self.mask`` for the datapoint, one per task
        """
        smiles = self.smiles[item]
        graph = self.graphs[item]
        labels = F.zerocopy_from_numpy(self.labels[item])
        mask = F.zerocopy_from_numpy(self.mask[item])
        return smiles, graph, labels, mask
Exemple #6
0
    def __call__(self, mol):
        """Featurizes the input molecule.

        Per atom, the feature row is the list returned by ``self._featurizer(atom)``.
        It is followed by a donor indicator and an acceptor indicator (as floats),
        then the number of rings of size 3, 4, ..., 8 (from the symmetrized SSSR)
        that contain the atom.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance. ``AllChem.ComputeGasteigerCharges`` is run on
            this object itself, not on a copy.

        Returns
        -------
        dict
            Mapping atom_data_field as specified in the input argument to the atom
            features, which is a float32 tensor of shape (N, M), N is the number of
            atoms and M is the feature size.
        """
        AllChem.ComputeGasteigerCharges(mol)
        num_atoms = mol.GetNumAtoms()

        # Get information for donor and acceptor
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        mol_feats = mol_featurizer.GetFeaturesForMol(mol)
        is_donor, is_acceptor = self.get_donor_acceptor_info(mol_feats)

        # Get a symmetrized smallest set of smallest rings
        # Following the practice from Chainer Chemistry (https://github.com/chainer/
        # chainer-chemistry/blob/da2507b38f903a8ee333e487d422ba6dcec49b05/chainer_chemistry/
        # dataset/preprocessors/weavenet_preprocessor.py)
        sssr = Chem.GetSymmSSSR(mol)

        # ring_counts[i][s - 3] = number of rings of size s (3 <= s <= 8) containing
        # atom i. A single pass over the rings replaces re-scanning every ring
        # for every atom. set() counts an atom at most once per ring, matching the
        # membership test ``i in ring``.
        ring_counts = [[0] * 6 for _ in range(num_atoms)]
        for ring in sssr:
            ring_size = len(ring)
            if 3 <= ring_size <= 8:
                for atom_idx in set(ring):
                    ring_counts[atom_idx][ring_size - 3] += 1

        atom_features = []
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            # Features that can be computed directly from RDKit atom instances, which is a list
            feats = self._featurizer(atom)
            # Donor/acceptor indicator
            feats.append(float(is_donor[i]))
            feats.append(float(is_acceptor[i]))
            # Number of rings the atom belongs to, for ring sizes 3 to 8
            feats.extend(ring_counts[i])
            atom_features.append(feats)
        atom_features = np.stack(atom_features)

        return {
            self._atom_data_field:
            F.zerocopy_from_numpy(atom_features.astype(np.float32))
        }
Exemple #7
0
    def __getitem__(self, item):
        """Fetch the datapoint at a given position.

        Parameters
        ----------
        item : int
            Position of the datapoint in the dataset

        Returns
        -------
        str
            SMILES string of the datapoint
        DGLGraph
            Graph built for the datapoint
        Tensor of dtype float32
            Label values of the datapoint for every task
        Tensor of dtype float32
            Binary mask marking which tasks have a label for the datapoint
        """
        return (self.smiles[item],
                self.graphs[item],
                F.zerocopy_from_numpy(self.labels[item]),
                F.zerocopy_from_numpy(self.mask[item]))
Exemple #8
0
    def _pre_process(self, smiles_to_graph, atom_featurizer, bond_featurizer):
        """Build (or reload) graphs, labels and label masks for the dataset.

        * Each SMILES string is turned into a featurized DGLGraph
        * Missing labels are replaced by 0 and recorded in a binary mask

        The results are cached at ``self.cache_file_path``; when that file exists
        it is loaded instead of recomputing.

        Parameters
        ----------
        smiles_to_graph : callable, SMILES -> DGLGraph
            Converts one SMILES string into a DGLGraph.
        atom_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Computes atom features, suitable for updating the ndata of a DGLGraph.
        bond_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Computes bond features, suitable for updating the edata of a DGLGraph.
        """
        if not os.path.exists(self.cache_file_path):
            print('Processing dgl graphs from scratch...')
            self.graphs = []
            for position, smiles in enumerate(self.smiles, start=1):
                print('Processing molecule {:d}/{:d}'.format(position, len(self)))
                graph = smiles_to_graph(smiles,
                                        atom_featurizer=atom_featurizer,
                                        bond_featurizer=bond_featurizer)
                self.graphs.append(graph)

            raw_labels = self.df[self.task_names].values
            # nan_to_num also maps inf to a very large finite number
            label_array = np.nan_to_num(raw_labels).astype(np.float32)
            mask_array = (~np.isnan(raw_labels)).astype(np.float32)
            self.labels = F.zerocopy_from_numpy(label_array)
            self.mask = F.zerocopy_from_numpy(mask_array)
            save_graphs(self.cache_file_path, self.graphs,
                        labels={'labels': self.labels, 'mask': self.mask})
            return

        # A cache from an earlier run exists: reuse it
        print('Loading previously saved dgl graphs...')
        self.graphs, label_dict = load_graphs(self.cache_file_path)
        self.labels = label_dict['labels']
        self.mask = label_dict['mask']
Exemple #9
0
    def pull_all(self, name):
        """Fetch every row of the data called ``name`` from the KVServer.

        Row IDs ``0 .. self._data_size[name] - 1`` are requested, in ascending order.

        Parameters
        ----------
        name : str
            data name

        Returns
        -------
        tensor
            target data tensor
        """
        row_count = self._data_size[name]
        all_row_ids = F.zerocopy_from_numpy(np.arange(row_count))
        return self.pull(name, all_row_ids)
Exemple #10
0
    def push_all(self, name, data):
        """Send every row of ``data`` to the KVServer under ``name``.

        Partitioning the message across KVServer nodes is left to ``self.push``.
        Row IDs ``0 .. len(data) - 1`` are sent, in ascending order.

        Parameters
        ----------
        name : str
            data name
        data : tensor (mx.ndarray or torch.tensor)
            data tensor
        """
        num_rows = F.shape(data)[0]
        row_ids = F.zerocopy_from_numpy(np.arange(num_rows))
        self.push(name, row_ids, data)
Exemple #11
0
 def __init__(self, dataset, args, weighting=False, ranks=64):
     """Build one training graph and split its edges into ``ranks`` parts.

     Parameters
     ----------
     dataset : object
         Dataset providing ``train`` (the training triples; ``len(train[0])`` is
         used as the number of training triples) and ``n_entities``.
     args : object
         Parsed arguments; ``args.rel_part`` selects relation-based partitioning.
     weighting : bool
         Whether to attach subsampling weights to edges. Not implemented yet.
     ranks : int
         Number of edge partitions. Default to 64.

     Raises
     ------
     NotImplementedError
         If ``weighting`` is True. The draft weighting code referenced names
         (``src``, ``etype_id``, ``dst``) that are not defined in this method.
     """
     if weighting:
         # TODO: subsampling weight per triple (h, r, t), intended as
         #   sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
         # with count = self.count_freq(triples). It needs per-edge head,
         # relation and tail arrays aligned with the graph's edge IDs.
         raise NotImplementedError(
             'Edge subsampling weighting is not implemented yet.')
     triples = dataset.train
     self.g = ConstructGraph(triples, dataset.n_entities, args)
     num_train = len(triples[0])
     print('|Train|:', num_train)
     if ranks > 1 and args.rel_part:
         self.edge_parts, self.rel_parts = RelationPartition(triples, ranks)
     elif ranks > 1:
         self.edge_parts = RandomPartition(triples, ranks)
     else:
         self.edge_parts = [np.arange(num_train)]
 def __init__(self, dataset, args, weighting=False, ranks=64):
     """Partition the training triples and build one graph per partition.

     Parameters
     ----------
     dataset : object
         Dataset providing ``train`` (the training triples) and ``n_entities``.
     args : object
         Parsed arguments; ``args.rel_part`` selects relation-based partitioning.
     weighting : bool
         Whether to attach subsampling weights to edges. Not implemented yet.
     ranks : int
         Number of partitions. Default to 64.

     Raises
     ------
     NotImplementedError
         If ``weighting`` is True. The draft weighting code referenced names
         (``src``, ``etype_id``, ``dst``) that are not defined in this method.
     """
     if weighting:
         # TODO: subsampling weight per triple (h, r, t), intended as
         #   sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
         # with count = self.count_freq(triples) for each partition. It needs
         # per-edge head, relation and tail arrays aligned with each graph's
         # edge IDs.
         raise NotImplementedError(
             'Edge subsampling weighting is not implemented yet.')
     triples = dataset.train
     print("|Train|:", len(triples))
     if ranks > 1 and args.rel_part:
         triples_list = RelationPartition(triples, ranks)
     elif ranks > 1:
         triples_list = RandomPartition(triples, ranks)
     else:
         triples_list = [triples]
     self.graphs = []
     for i, part_triples in enumerate(triples_list):
         # One graph per partition; the partition index is passed to ConstructGraph.
         self.graphs.append(
             ConstructGraph(part_triples, dataset.n_entities, i, args))
Exemple #13
0
    def __call__(self, mol):
        """Compute features for every atom of a molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Single entry keyed by ``self.atom_data_field`` holding a float32 tensor
            of shape (N, 74), N being the number of atoms in the molecule
        """
        per_atom_rows = [self._featurize_atom(mol.GetAtomWithIdx(atom_idx))
                         for atom_idx in range(mol.GetNumAtoms())]
        feature_matrix = np.stack(per_atom_rows).astype(np.float32)
        return {self.atom_data_field: F.zerocopy_from_numpy(feature_matrix)}
def PN_graph_construction_and_featurization(ligand_mol,
                                            protein_mol,
                                            ligand_coordinates,
                                            protein_coordinates,
                                            max_num_ligand_atoms=None,
                                            max_num_protein_atoms=None,
                                            max_num_neighbors=4,
                                            distance_bins=(1.5, 2.5, 3.5, 4.5),
                                            strip_hydrogens=False):
    """Graph construction and featurization for `PotentialNet for Molecular Property Prediction
     <https://pubs.acs.org/doi/10.1021/acscentsci.8b00507>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float array of shape (V1, 3)
        Atom coordinates in a ligand. Indexed with ``.take(..., axis=0)`` when
        ``strip_hydrogens`` is True, so a numpy array is expected there.
    protein_coordinates : Float array of shape (V2, 3)
        Atom coordinates in a protein. Same requirements as ``ligand_coordinates``.
    max_num_ligand_atoms : int or None
        If not None, must be no smaller than ligand_mol.GetNumAtoms(). It is only
        validated here; this function does not perform any padding. Default to None.
    max_num_protein_atoms : int or None
        If not None, must be no smaller than protein_mol.GetNumAtoms(). It is only
        validated here; this function does not perform any padding. Default to None.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom when constructing KNN graph. Default to 4.
    distance_bins : sequence of float
        Distance bins to determine the edge types of KNN edges.
        Edges of the first edge type connect pairs of atoms whose distance is no
        greater than ``distance_bins[0]``. Edges of type ``i`` have distances in
        ``(distance_bins[i - 1], distance_bins[i]]``. ``distance_bins[-1]`` is also
        the neighbor cutoff for the KNN search.
        Default ``(1.5, 2.5, 3.5, 4.5)``.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms from the coordinates. Default to False.

    Returns
    -------
    complex_bigraph : DGLGraph
        Bigraph with the ligand and the protein (pocket) combined and canonical features extracted.
        The atom features are stored as DGLGraph.ndata['h'].
        The edge types are stored as DGLGraph.edata['e'].
        The bigraphs of the ligand and the protein are batched together as one complex graph.
    complex_knn_graph : DGLGraph
        K-nearest-neighbor graph with the ligand and the protein (pocket) combined and edge
        features extracted based on distances. The bond edges of ``complex_bigraph`` are
        appended after the KNN edges. DGLGraph.edata['e'] is an int64 block matrix: a one-hot
        distance type for KNN edges, followed by the bond features for the appended bond edges.
    """

    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left,
                                                     axis=0)
        protein_coordinates = protein_coordinates.take(
            protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))

    # Node featurizer for stage 1
    atoms = [
        'H', 'N', 'O', 'C', 'P', 'S', 'F', 'Br', 'Cl', 'I', 'Fe', 'Zn', 'Mg',
        'Na', 'Mn', 'Ca', 'Co', 'Ni', 'Se', 'Cu', 'Cd', 'Hg', 'K'
    ]
    atom_total_degrees = list(range(5))
    atom_formal_charges = [-1, 0, 1]
    atom_implicit_valence = list(range(4))
    atom_explicit_valence = list(range(8))
    atom_concat_featurizer = ConcatFeaturizer([
        partial(atom_type_one_hot, allowable_set=atoms),
        partial(atom_total_degree_one_hot, allowable_set=atom_total_degrees),
        partial(atom_formal_charge_one_hot, allowable_set=atom_formal_charges),
        atom_is_aromatic,
        partial(atom_implicit_valence_one_hot,
                allowable_set=atom_implicit_valence),
        partial(atom_explicit_valence_one_hot,
                allowable_set=atom_explicit_valence)
    ])
    PN_atom_featurizer = BaseAtomFeaturizer({'h': atom_concat_featurizer})

    # Bond featurizer for stage 1
    bond_concat_featurizer = ConcatFeaturizer(
        [bond_type_one_hot, bond_is_in_ring])
    PN_bond_featurizer = BaseBondFeaturizer({'e': bond_concat_featurizer})

    # construct graphs for stage 1
    # NOTE(review): ligand_mol / protein_mol are used unmodified here. With
    # strip_hydrogens=True, the bigraphs therefore still contain hydrogen atoms,
    # while the KNN graph below is built from the filtered coordinates. The node
    # IDs of the two graphs may then disagree - confirm this is intended.
    ligand_bigraph = mol_to_bigraph(
        ligand_mol,
        add_self_loop=False,
        node_featurizer=PN_atom_featurizer,
        edge_featurizer=PN_bond_featurizer,
        canonical_atom_order=False)  # Keep the original atomic order
    protein_bigraph = mol_to_bigraph(protein_mol,
                                     add_self_loop=False,
                                     node_featurizer=PN_atom_featurizer,
                                     edge_featurizer=PN_bond_featurizer,
                                     canonical_atom_order=False)
    complex_bigraph = batch([ligand_bigraph, protein_bigraph])

    # Construct knn graphs for stage 2
    complex_coordinates = np.concatenate(
        [ligand_coordinates, protein_coordinates])
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        complex_coordinates, distance_bins[-1], max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)

    complex_knn_graph = graph((complex_srcs, complex_dsts),
                              num_nodes=len(complex_coordinates))
    # right=True: type i covers distances in (bins[i - 1], bins[i]]
    d_features = np.digitize(complex_dists, bins=distance_bins, right=True)
    d_one_hot = int_2_one_hot(d_features)

    # add bond types and bonds (from bigraph) to stage 2
    u, v = complex_bigraph.edges()
    complex_knn_graph.add_edges(u.to(F.int64), v.to(F.int64))
    n_d, f_d = d_one_hot.shape
    n_e, f_e = complex_bigraph.edata['e'].shape
    # Block layout: KNN edges carry [distance one-hot | zeros] and bond edges
    # carry [zeros | bond features]. np.int64 replaces the alias np.long, which
    # was removed in NumPy 1.24.
    complex_knn_graph.edata['e'] = F.zerocopy_from_numpy(
        np.block([[d_one_hot, np.zeros((n_d, f_e))],
                  [np.zeros((n_e, f_d)),
                   np.array(complex_bigraph.edata['e'])]]).astype(np.int64))
    return complex_bigraph, complex_knn_graph
Exemple #15
0
    def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer,
                     load, log_every, init_mask, n_jobs, error_log):
        """Pre-process the dataset

        * Convert molecules from smiles format into DGLGraphs
          and featurize their atoms
        * Set missing labels to be 0 and use a binary masking
          matrix to mask them
        * Drop molecules for which ``smiles_to_graph`` returned None, keeping the
          indices of the remaining ones in ``self.valid_ids``

        Parameters
        ----------
        smiles_to_graph : callable, SMILES -> DGLGraph
            Function for converting a SMILES (str) into a DGLGraph.
        node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for nodes like atoms in a molecule, which can be used to update
            ndata for a DGLGraph.
        edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
            Featurization for edges like bonds in a molecule, which can be used to update
            edata for a DGLGraph.
        load : bool
            Whether to load the previously pre-processed dataset or pre-process from scratch.
            ``load`` should be False when we want to try different graph construction and
            featurization methods and need to preprocess from scratch.
        log_every : int
            Print a message every time ``log_every`` molecules are processed. It only comes
            into effect when :attr:`n_jobs` is not greater than 1 (sequential processing).
            A value of 0 or None disables these messages.
        init_mask : bool
            Whether to initialize a binary mask indicating the existence of labels.
            If False, ``self.mask`` is set to None.
        n_jobs : int
            Degree of parallelism for pre processing. Default to 1.
        error_log : str
            Path to a CSV file of molecules that RDKit failed to parse. If not specified,
            the molecules will not be recorded.
        """
        if os.path.exists(self.cache_file_path) and load:
            # DGLGraphs have been constructed before, reload them
            print('Loading previously saved dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
            self.labels = label_dict['labels']
            # Match the from-scratch path, which sets mask to None when not requested.
            self.mask = label_dict['mask'] if init_mask else None
            self.valid_ids = label_dict['valid_ids'].tolist()
        else:
            print('Processing dgl graphs from scratch...')
            if n_jobs > 1:
                self.graphs = pmap(smiles_to_graph,
                                   self.smiles,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer,
                                   n_jobs=n_jobs)
            else:
                self.graphs = []
                for i, s in enumerate(self.smiles):
                    if log_every and (i + 1) % log_every == 0:
                        print('Processing molecule {:d}/{:d}'.format(i+1, len(self)))
                    self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer,
                                                       edge_featurizer=edge_featurizer))

            # Keep only valid molecules (smiles_to_graph returns None on failure)
            self.valid_ids = []
            graphs = []
            failed_mols = []
            for i, g in enumerate(self.graphs):
                if g is not None:
                    self.valid_ids.append(i)
                    graphs.append(g)
                else:
                    failed_mols.append((i, self.smiles[i]))

            if error_log is not None:
                if len(failed_mols) > 0:
                    failed_ids, failed_smis = map(list, zip(*failed_mols))
                else:
                    failed_ids, failed_smis = [], []
                df = pd.DataFrame({'raw_id': failed_ids, 'smiles': failed_smis})
                df.to_csv(error_log, index=False)

            self.graphs = graphs
            # Restrict labels to the valid molecules before converting.
            _label_values = self.df[self.task_names].values[self.valid_ids]
            # np.nan_to_num will also turn inf into a very large number
            self.labels = F.zerocopy_from_numpy(
                np.nan_to_num(_label_values).astype(np.float32))
            saved_tensors = {'labels': self.labels}
            if init_mask:
                self.mask = F.zerocopy_from_numpy(
                    (~np.isnan(_label_values)).astype(np.float32))
                saved_tensors['mask'] = self.mask
            else:
                self.mask = None
            # Explicit int64 so an empty id list is not stored as a float tensor.
            saved_tensors['valid_ids'] = torch.tensor(self.valid_ids, dtype=torch.int64)
            save_graphs(self.cache_file_path, self.graphs, labels=saved_tensors)

        self.smiles = [self.smiles[i] for i in self.valid_ids]
def ACNN_graph_construction_and_featurization(ligand_mol,
                                              protein_mol,
                                              ligand_coordinates,
                                              protein_coordinates,
                                              max_num_ligand_atoms=None,
                                              max_num_protein_atoms=None,
                                              neighbor_cutoff=12.,
                                              max_num_neighbors=12,
                                              strip_hydrogens=False):
    """Graph construction and featurization for `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        Maximum number of atoms in ligands for zero padding, which should be no smaller than
        ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller than
        protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    g : DGLHeteroGraph
        Heterograph with node types ``'ligand_atom'`` and ``'protein_atom'``. Both node
        types store ``'atomic_number'`` (0 for padding nodes) and ``'mask'`` (1 for real
        atoms, 0 for padding nodes) as float32 tensors of shape (N, 1). Edge types
        ``'ligand'``, ``'protein'`` and the four ``'complex'`` relations within/between
        ligand and protein atoms store ``'distance'`` as float32 tensors of shape (E, 1).
    """
    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    def _distance_feats(dists):
        """Convert a sequence of distances into a float32 tensor of shape (E, 1)."""
        return F.reshape(F.zerocopy_from_numpy(np.asarray(dists).astype(np.float32)),
                         (-1, 1))

    def _padded_column(values, num_rows):
        """Place ``values`` in the first rows of a zero float32 tensor of shape (num_rows, 1)."""
        column = np.zeros((num_rows, 1), dtype=np.float32)
        column[:len(values), 0] = values
        return F.zerocopy_from_numpy(column)

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left,
                                                     axis=0)
        protein_coordinates = protein_coordinates.take(
            protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))

    # Number of real (non-padding) atoms of each type
    num_ligand_atoms_left = len(ligand_atom_indices_left)
    num_protein_atoms_left = len(protein_atom_indices_left)

    # Number of nodes of each type, including zero padding if requested
    if max_num_ligand_atoms is None:
        num_ligand_atoms = num_ligand_atoms_left
    else:
        num_ligand_atoms = max_num_ligand_atoms

    if max_num_protein_atoms is None:
        num_protein_atoms = num_protein_atoms_left
    else:
        num_protein_atoms = max_num_protein_atoms

    data_dict = dict()
    num_nodes_dict = dict()

    # graph data for atoms in the ligand
    ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
        ligand_coordinates, neighbor_cutoff, max_num_neighbors)
    data_dict[('ligand_atom', 'ligand', 'ligand_atom')] = (ligand_srcs,
                                                           ligand_dsts)
    num_nodes_dict['ligand_atom'] = num_ligand_atoms

    # graph data for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    data_dict[('protein_atom', 'protein', 'protein_atom')] = (protein_srcs,
                                                              protein_dsts)
    num_nodes_dict['protein_atom'] = num_protein_atoms

    # 4 graphs for complex representation, including the connection within
    # protein atoms, the connection within ligand atoms and the connection between
    # protein and ligand atoms.
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        np.concatenate([ligand_coordinates, protein_coordinates]),
        neighbor_cutoff, max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)
    # Rows [0, offset) of the concatenated coordinates are ligand atoms and the rest are
    # protein atoms. The concatenation holds no padding rows, so the split point is the
    # number of real ligand atoms, not the (possibly padded) ligand node count.
    offset = num_ligand_atoms_left
    src_is_ligand = complex_srcs < offset
    dst_is_ligand = complex_dsts < offset

    # (canonical edge type, edge selector, shift for src ids, shift for dst ids)
    complex_relations = [
        (('ligand_atom', 'complex', 'ligand_atom'),
         src_is_ligand & dst_is_ligand, 0, 0),
        (('protein_atom', 'complex', 'protein_atom'),
         ~src_is_ligand & ~dst_is_ligand, offset, offset),
        (('ligand_atom', 'complex', 'protein_atom'),
         src_is_ligand & ~dst_is_ligand, 0, offset),
        (('protein_atom', 'complex', 'ligand_atom'),
         ~src_is_ligand & dst_is_ligand, offset, 0),
    ]
    complex_edge_dists = dict()
    for etype, selected, src_shift, dst_shift in complex_relations:
        # Protein atoms are re-indexed into the protein node space by subtracting offset
        data_dict[etype] = ((complex_srcs[selected] - src_shift).tolist(),
                            (complex_dsts[selected] - dst_shift).tolist())
        complex_edge_dists[etype] = complex_dists[selected]

    g = heterograph(data_dict, num_nodes_dict=num_nodes_dict)
    g.edges['ligand'].data['distance'] = _distance_feats(ligand_dists)
    g.edges['protein'].data['distance'] = _distance_feats(protein_dists)
    for etype, dists in complex_edge_dists.items():
        g.edges[etype].data['distance'] = _distance_feats(dists)

    # Atomic numbers for all atoms left; padding nodes get 0
    g.nodes['ligand_atom'].data['atomic_number'] = _padded_column(
        get_atomic_numbers(ligand_mol, ligand_atom_indices_left), num_ligand_atoms)
    g.nodes['protein_atom'].data['atomic_number'] = _padded_column(
        get_atomic_numbers(protein_mol, protein_atom_indices_left), num_protein_atoms)

    # Mask indicating the existence of nodes (1 for real atoms, 0 for padding)
    g.nodes['ligand_atom'].data['mask'] = _padded_column(
        np.ones(num_ligand_atoms_left), num_ligand_atoms)
    g.nodes['protein_atom'].data['mask'] = _padded_column(
        np.ones(num_protein_atoms_left), num_protein_atoms)

    return g
Exemple #17
0
    def _preprocess(self, root_path, index_label_file, load_binding_pocket,
                    sanitize, calc_charges, remove_hs, use_conformation,
                    construct_graph_and_featurize, zero_padding, num_processes):
        """Preprocess the dataset.

        The pre-processing proceeds as follows:

        1. Load the dataset
        2. Clean the dataset and filter out invalid pairs
        3. Construct graphs
        4. Prepare node and edge features

        Parameters
        ----------
        root_path : str
            Root path for molecule files.
        index_label_file : str
            Path to the index file for the dataset. Lines starting with ``#`` and blank
            lines are skipped; every other line is expected to have 8 whitespace-separated
            fields, the 6th of which ("//") is dropped.
        load_binding_pocket : bool
            Whether to load binding pockets or full proteins.
        sanitize : bool
            Whether sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        calc_charges : bool
            Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
            ``sanitize`` to be True.
        remove_hs : bool
            Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
            slow for large molecules.
        use_conformation : bool
            Whether we need to extract molecular conformation from proteins and ligands.
        construct_graph_and_featurize : callable
            Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
            self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
            to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
        zero_padding : bool
            Whether to perform zero padding. While DGL does not necessarily require zero padding,
            pooling operations for variable length inputs can introduce stochastic behaviour, which
            is not desired for sensitive scenarios.
        num_processes : int or None
            Number of worker processes to use. If None,
            then we will use the number of CPUs in the system.
        """
        import multiprocessing

        contents = []
        with open(index_label_file, 'r') as f:
            for line in f:
                # Skip header/comment lines and blank lines
                if line.startswith('#') or not line.strip():
                    continue
                splitted_elements = line.split()
                if len(splitted_elements) == 8:
                    # Ignore "//"
                    contents.append(splitted_elements[:5] + splitted_elements[6:])
                else:
                    print('Incorrect data format.')
                    print(splitted_elements)
        self.df = pd.DataFrame(contents, columns=(
            'PDB_code', 'resolution', 'release_year',
            '-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
        pdbs = self.df['PDB_code'].tolist()

        self.ligand_files = [os.path.join(
            root_path, 'v2015', pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
        if load_binding_pocket:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
        else:
            self.protein_files = [os.path.join(
                root_path, 'v2015', pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]

        # None means "use every CPU", as documented above; min(None, ...) would raise.
        if num_processes is None:
            num_processes = multiprocessing.cpu_count()
        # At least one worker, and no more workers than there are entries to load.
        num_processes = max(1, min(num_processes, len(pdbs)))

        print('Loading ligands...')
        ligands_loaded = multiprocess_load_molecules(self.ligand_files,
                                                     sanitize=sanitize,
                                                     calc_charges=calc_charges,
                                                     remove_hs=remove_hs,
                                                     use_conformation=use_conformation,
                                                     num_processes=num_processes)

        print('Loading proteins...')
        proteins_loaded = multiprocess_load_molecules(self.protein_files,
                                                      sanitize=sanitize,
                                                      calc_charges=calc_charges,
                                                      remove_hs=remove_hs,
                                                      use_conformation=use_conformation,
                                                      num_processes=num_processes)

        self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
        self.df = self.df.iloc[self.indices]
        self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
        print('Finished cleaning the dataset, '
              'got {:d}/{:d} valid pairs'.format(len(self), len(pdbs)))

        # Prepare zero padding: pad every graph to the largest molecule in the dataset
        if zero_padding:
            max_num_ligand_atoms = 0
            max_num_protein_atoms = 0
            for i in range(len(self)):
                max_num_ligand_atoms = max(
                    max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
                max_num_protein_atoms = max(
                    max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
        else:
            max_num_ligand_atoms = None
            max_num_protein_atoms = None

        print('Start constructing graphs and featurizing them.')
        self.graphs = []
        for i in range(len(self)):
            print('Constructing and featurizing datapoint {:d}/{:d}'.format(i+1, len(self)))
            self.graphs.append(construct_graph_and_featurize(
                self.ligand_mols[i], self.protein_mols[i],
                self.ligand_coordinates[i], self.protein_coordinates[i],
                max_num_ligand_atoms, max_num_protein_atoms))
Exemple #18
0
    def _preprocess(self, load_binding_pocket, sanitize, calc_charges,
                    remove_hs, use_conformation, construct_graph_and_featurize,
                    zero_padding, num_processes):
        """Preprocess the dataset.

        The pre-processing proceeds as follows:

        1. Load the dataset
        2. Clean the dataset and filter out invalid pairs
        3. Construct graphs
        4. Prepare node and edge features

        Parameters
        ----------
        load_binding_pocket : bool
            Whether to load binding pockets or full proteins.
        sanitize : bool
            Whether sanitization is performed in initializing RDKit molecule instances. See
            https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        calc_charges : bool
            Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
            ``sanitize`` to be True.
        remove_hs : bool
            Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
            slow for large molecules.
        use_conformation : bool
            Whether we need to extract molecular conformation from proteins and ligands.
        construct_graph_and_featurize : callable
            Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
            self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
            to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
            It must be picklable, since graphs are built in worker processes.
        zero_padding : bool
            Whether to perform zero padding. While DGL does not necessarily require zero padding,
            pooling operations for variable length inputs can introduce stochastic behaviour, which
            is not desired for sensitive scenarios.
        num_processes : int or None
            Number of worker processes to use. If None,
            then we will use the number of CPUs in the system.
        """
        # Remember the size before filtering so the summary reports valid/total pairs.
        num_pairs_total = len(self.df)

        if num_processes is None:
            num_processes = multiprocessing.cpu_count()
        # multiprocessing.Pool rejects 0 workers, and more workers than pairs is wasteful.
        num_processes = max(1, min(num_processes, num_pairs_total))

        print('Loading ligands...')
        ligands_loaded = multiprocess_load_molecules(
            self.ligand_files,
            sanitize=sanitize,
            calc_charges=calc_charges,
            remove_hs=remove_hs,
            use_conformation=use_conformation,
            num_processes=num_processes)

        print('Loading proteins...')
        proteins_loaded = multiprocess_load_molecules(
            self.protein_files,
            sanitize=sanitize,
            calc_charges=calc_charges,
            remove_hs=remove_hs,
            use_conformation=use_conformation,
            num_processes=num_processes)

        self._filter_out_invalid(ligands_loaded, proteins_loaded,
                                 use_conformation)
        self.df = self.df.iloc[self.indices]
        self.labels = F.zerocopy_from_numpy(
            self.df[self.task_names].values.astype(np.float32))
        print('Finished cleaning the dataset, '
              'got {:d}/{:d} valid pairs'.format(len(self), num_pairs_total))

        # Prepare zero padding: pad every graph to the largest molecule in the dataset
        if zero_padding:
            max_num_ligand_atoms = 0
            max_num_protein_atoms = 0
            for i in range(len(self)):
                max_num_ligand_atoms = max(max_num_ligand_atoms,
                                           self.ligand_mols[i].GetNumAtoms())
                max_num_protein_atoms = max(max_num_protein_atoms,
                                            self.protein_mols[i].GetNumAtoms())
        else:
            max_num_ligand_atoms = None
            max_num_protein_atoms = None

        construct_graph_and_featurize = partial(
            construct_graph_and_featurize,
            max_num_ligand_atoms=max_num_ligand_atoms,
            max_num_protein_atoms=max_num_protein_atoms)

        print('Start constructing graphs and featurizing them.')
        # Construct graphs with multiprocessing. The context manager shuts the workers
        # down once starmap has returned every result, so no processes are leaked.
        with multiprocessing.Pool(processes=num_processes) as pool:
            self.graphs = pool.starmap(
                construct_graph_and_featurize,
                zip(self.ligand_mols, self.protein_mols, self.ligand_coordinates,
                    self.protein_coordinates))
        print(f'Done constructing {len(self.graphs)} graphs.')
def XYZ_graph_construction_and_featurization(protein_mol,
                                             protein_coordinates,
                                             max_num_protein_atoms=None,
                                             neighbor_cutoff=12.,
                                             max_num_neighbors=12,
                                             strip_hydrogens=False):
    """Protein-only graph construction and featurization, adapted from the ACNN recipe of
    `Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity
    <https://arxiv.org/abs/1703.10603>`__.

    Only the protein k-nearest-neighbor graph is built; no ligand or complex graphs are
    constructed.

    Parameters
    ----------
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance. Objects exposing the atom count as an ``n_atoms``
        attribute instead of ``GetNumAtoms()`` are also accepted.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller
        than the number of atoms kept. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    g : DGLGraph
        Graph with node type ``'protein_atom'`` and edge type ``'protein'``. Edges store
        ``'distance'`` (float32, shape (E, 1)); nodes store ``'atomic_number'`` (0 for
        padding nodes) and ``'mask'`` (1 for real atoms, 0 for padding), both float32
        tensors of shape (N, 1).
    """
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
    else:
        # RDKit molecules expose GetNumAtoms(); fall back to an ``n_atoms`` attribute for
        # molecule objects that report their size that way.
        if hasattr(protein_mol, 'GetNumAtoms'):
            num_atoms_total = protein_mol.GetNumAtoms()
        else:
            num_atoms_total = protein_mol.n_atoms
        protein_atom_indices_left = list(range(num_atoms_total))

    # Compute number of nodes, including zero padding if requested
    num_atoms_left = len(protein_atom_indices_left)
    if max_num_protein_atoms is None:
        num_protein_atoms = num_atoms_left
    else:
        assert max_num_protein_atoms >= num_atoms_left, \
            'Expect max_num_protein_atoms to be no smaller than the number of protein ' \
            'atoms kept, got {:d} and {:d}'.format(max_num_protein_atoms, num_atoms_left)
        num_protein_atoms = max_num_protein_atoms

    # Construct graph for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    g = graph((protein_srcs, protein_dsts),
              'protein_atom', 'protein', num_protein_atoms)
    g.edata['distance'] = F.reshape(F.zerocopy_from_numpy(
        np.array(protein_dists).astype(np.float32)), (-1, 1))

    # Atomic numbers for the atoms kept, zero padded up to num_protein_atoms
    protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
    protein_atomic_numbers = np.concatenate([
        protein_atomic_numbers, np.zeros(num_protein_atoms - num_atoms_left)])
    g.nodes['protein_atom'].data['atomic_number'] = F.reshape(F.zerocopy_from_numpy(
        protein_atomic_numbers.astype(np.float32)), (-1, 1))

    # Mask indicating the existence of nodes (1 for real atoms, 0 for padding)
    protein_masks = np.zeros((num_protein_atoms, 1))
    protein_masks[:num_atoms_left, :] = 1
    g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
        protein_masks.astype(np.float32))

    return g
Exemple #20
0
def potentialNet_graph_construction_featurization(
        ligand_mol,
        protein_mol,
        ligand_coordinates,
        protein_coordinates,
        max_num_ligand_atoms=None,
        max_num_protein_atoms=None,
        max_num_neighbors=4,
        distance_bins=(1.5, 2.5, 3.5, 4.5),
        strip_hydrogens=False):
    """Graph construction and featurization for `PotentialNet for Molecular Property Prediction
     <https://pubs.acs.org/doi/10.1021/acscentsci.8b00507>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        If not None, it must be no smaller than ligand_mol.GetNumAtoms(). The value is
        only validated; this function does not perform zero padding. Default to None.
    max_num_protein_atoms : int or None
        If not None, it must be no smaller than protein_mol.GetNumAtoms(). The value is
        only validated; this function does not perform zero padding. Default to None.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 4.
    distance_bins : sequence of float
        Increasing bin edges used to turn kNN distances into edge types (via
        ``np.digitize(..., right=True)``). The last edge is also the neighbor distance
        cutoff. Default to (1.5, 2.5, 3.5, 4.5).
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms from the coordinates used for the kNN graph.
        Default to False. NOTE(review): the bigraphs are still built from the full
        molecules, so with hydrogens stripped the bond edges added to the kNN graph may
        not line up with its node indices. Confirm before relying on this option.

    Returns
    -------
    complex_bigraph : DGLGraph
        Bigraph with ligand and protein (pocket) combined and canonical features extracted.
        The atom features are stored as DGLGraph.ndata['h'].
        The edge types are stored as DGLGraph.edata['e'].
        The bigraphs of ligand and protein are batched together as one complex graph.
    complex_knn_graph : DGLGraph
        K-nearest-neighbor graph with ligand and protein (pocket) combined and edge features extracted based on distances.
        The edge types are stored as DGLGraph.edata['e'].
        The knn-graphs of ligand and protein are batched together as one complex graph.

    """

    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    if strip_hydrogens:
        # Remove the coordinates of hydrogen atoms
        ligand_coordinates = ligand_coordinates.take(
            filter_out_hydrogens(ligand_mol), axis=0)
        protein_coordinates = protein_coordinates.take(
            filter_out_hydrogens(protein_mol), axis=0)

    # Construct bigraph for stage 1
    node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='e')
    ligand_bigraph = mol_to_bigraph(
        ligand_mol,
        add_self_loop=False,
        node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer,
        canonical_atom_order=False,  # Keep the original atomic order
    )
    protein_bigraph = mol_to_bigraph(protein_mol,
                                     add_self_loop=False,
                                     node_featurizer=node_featurizer,
                                     edge_featurizer=edge_featurizer,
                                     canonical_atom_order=False)

    complex_bigraph = batch([ligand_bigraph, protein_bigraph])
    # remove features that never appear
    zero_h_cols = [
        5, 13, 14, 16, 17, 19, 20, 21, 22, 23, 24, 27, 30, 31, 33, 36, 38, 39,
        40, 42, 50, 51, 52, 53, 58, 59, 60, 62, 63, 64, 65, 66, 67, 73
    ]
    zero_e_cols = [4, 5, 7, 8, 9, 10, 11]
    complex_bigraph.ndata['h'] = np.delete(complex_bigraph.ndata['h'],
                                           zero_h_cols,
                                           axis=1)
    complex_bigraph.edata['e'] = np.delete(complex_bigraph.edata['e'],
                                           zero_e_cols,
                                           axis=1)  # 5 edge types remain

    # Construct knn graph for stage 2
    complex_coordinates = np.concatenate(
        [ligand_coordinates, protein_coordinates])
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        complex_coordinates, distance_bins[-1], max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)

    complex_knn_graph = graph([])
    complex_knn_graph.add_nodes(len(complex_coordinates))
    complex_knn_graph.add_edges(complex_srcs, complex_dsts)
    d_features = np.digitize(complex_dists, bins=distance_bins, right=True)
    d_one_hot = int_2_one_hot(d_features)

    # add bond types and bonds (from bigraph) to stage 2
    u, v = complex_bigraph.edges()
    complex_knn_graph.add_edges(u.to(F.int64), v.to(F.int64))
    n_d, f_d = d_one_hot.shape
    n_e, f_e = complex_bigraph.edata['e'].shape
    # Block-diagonal edge features: distance one-hots for kNN edges, bond features for
    # bond edges. np.long was removed in NumPy 1.24, so use an explicit int64 dtype.
    complex_knn_graph.edata['e'] = F.zerocopy_from_numpy(
        np.block([[d_one_hot, np.zeros((n_d, f_e))],
                  [np.zeros((n_e, f_d)),
                   np.array(complex_bigraph.edata['e'])]]).astype(np.int64))
    return complex_bigraph, complex_knn_graph