Ejemplo n.º 1
0
    def _featurize(self, mol):
        """Encode ``mol`` as a WeaveMol object.

        Parameters
        ----------
        mol
            Molecule object exposing ``GetAtoms()`` and ``GetBonds()``
            (presumably an RDKit ``Mol`` -- confirm against the caller).

        Returns
        -------
        WeaveMol
            Built from the stacked atom features and the pair features.

        Raises
        ------
        ValueError
            If ``mol`` has no atoms: ``np.vstack`` cannot stack an empty
            sequence.

        Notes
        -----
        Reads ``self.explicit_H`` and ``self.graph_distance``.
        """
        # Atom features, each paired with its atom index so they can be
        # put in index order.
        indexed_features = [
            (atom.GetIdx(), atom_features(atom, explicit_H=self.explicit_H))
            for atom in mol.GetAtoms()
        ]
        # Order by atom index only. Keying on the index means the feature
        # arrays themselves are never compared during the sort.
        indexed_features.sort(key=lambda pair: pair[0])

        # Stack the per-atom features along the first axis; len(nodes) is
        # used below as the number of atoms.
        nodes = np.vstack([features for _, features in indexed_features])

        # Map each bond, keyed by its (low, high) atom-index pair, to its
        # bond features.
        edge_list = {}
        for bond in mol.GetBonds():
            key = tuple(sorted((bond.GetBeginAtomIdx(),
                                bond.GetEndAtomIdx())))
            edge_list[key] = bond_features(bond)

        # Adjacency list: entry i holds the indices of atoms bonded to i.
        canon_adj_list = [[] for _ in range(len(nodes))]
        for begin, end in edge_list:
            canon_adj_list[begin].append(end)
            canon_adj_list[end].append(begin)

        # Calculate pair features
        pairs = pair_features(mol,
                              edge_list,
                              canon_adj_list,
                              bt_len=6,
                              graph_distance=self.graph_distance)

        return WeaveMol(nodes, pairs)
Ejemplo n.º 2
0
    def _featurize(self, mol):
        """Encode ``mol`` as a WeaveMol object.

        Parameters
        ----------
        mol
            Molecule object exposing ``GetAtoms()`` and ``GetBonds()``
            (presumably an RDKit ``Mol`` -- confirm against the caller).

        Returns
        -------
        WeaveMol
            Built from the stacked atom features, the pair features and
            the pair edges returned by ``pair_features``.

        Raises
        ------
        ValueError
            If ``mol`` has no atoms: ``np.vstack`` cannot stack an empty
            sequence.

        Notes
        -----
        Reads ``self.explicit_H``, ``self.use_chirality``, ``self.bt_len``,
        ``self.graph_distance`` and ``self.max_pair_distance``.
        """
        # Atom features, each paired with its atom index so they can be
        # put in index order.
        indexed_features = [
            (atom.GetIdx(),
             atom_features(atom,
                           explicit_H=self.explicit_H,
                           use_chirality=self.use_chirality))
            for atom in mol.GetAtoms()
        ]
        # Order by atom index only. Keying on the index means the feature
        # arrays themselves are never compared during the sort.
        indexed_features.sort(key=lambda pair: pair[0])

        # Stack the per-atom features along the first axis; len(nodes) is
        # used below as the number of atoms.
        nodes = np.vstack([features for _, features in indexed_features])

        # Map each bond, keyed by its (low, high) atom-index pair, to its
        # bond features.
        bond_features_map = {}
        for bond in mol.GetBonds():
            key = tuple(sorted((bond.GetBeginAtomIdx(),
                                bond.GetEndAtomIdx())))
            bond_features_map[key] = bond_features(
                bond, use_chirality=self.use_chirality)

        # Adjacency list: entry i holds the indices of atoms bonded to i.
        bond_adj_list = [[] for _ in range(len(nodes))]
        for begin, end in bond_features_map:
            bond_adj_list[begin].append(end)
            bond_adj_list[end].append(begin)

        # Calculate pair features
        pairs, pair_edges = pair_features(
            mol,
            bond_features_map,
            bond_adj_list,
            bt_len=self.bt_len,
            graph_distance=self.graph_distance,
            max_pair_distance=self.max_pair_distance)

        return WeaveMol(nodes, pairs, pair_edges)